shrish191 commited on
Commit
c364051
·
verified ·
1 Parent(s): a062f7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -182
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
- '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,195 +534,16 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
537
- '''
538
-
539
- import gradio as gr
540
- from transformers import TFBertForSequenceClassification, BertTokenizer
541
- import tensorflow as tf
542
- import praw
543
- import os
544
- import pytesseract
545
- from PIL import Image
546
- import cv2
547
- import numpy as np
548
- import re
549
-
550
- from evaluate import get_classification_report
551
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
552
- import torch
553
- from scipy.special import softmax
554
- import matplotlib.pyplot as plt
555
- import pandas as pd
556
-
557
- # Install tesseract OCR (only runs once in Hugging Face Spaces)
558
- os.system("apt-get update && apt-get install -y tesseract-ocr")
559
-
560
- # Load main model
561
- model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
562
- tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
563
-
564
- LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
565
-
566
- # Load fallback model
567
- fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
568
- fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
569
- fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
570
-
571
- # Reddit API setup
572
- reddit = praw.Reddit(
573
- client_id=os.getenv("REDDIT_CLIENT_ID"),
574
- client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
575
- user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
576
- )
577
-
578
- def fetch_reddit_text(reddit_url):
579
- try:
580
- submission = reddit.submission(url=reddit_url)
581
- return f"{submission.title}\n\n{submission.selftext}"
582
- except Exception as e:
583
- return f"Error fetching Reddit post: {str(e)}"
584
 
585
- def fallback_classifier(text):
586
- encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
587
- with torch.no_grad():
588
- output = fallback_model(**encoded_input)
589
- scores = softmax(output.logits.numpy()[0])
590
- labels = ['Negative', 'Neutral', 'Positive']
591
- return f"Prediction: {labels[scores.argmax()]}"
592
 
593
- def clean_ocr_text(text):
594
- text = text.strip()
595
- text = re.sub(r'\s+', ' ', text)
596
- text = re.sub(r'[^\x00-\x7F]+', '', text)
597
- return text
598
 
599
- def classify_sentiment(text_input, reddit_url, image):
600
- if reddit_url.strip():
601
- text = fetch_reddit_text(reddit_url)
602
- elif image is not None:
603
- try:
604
- img_array = np.array(image)
605
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
606
- _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
607
- text = pytesseract.image_to_string(thresh)
608
- text = clean_ocr_text(text)
609
- except Exception as e:
610
- return f"[!] OCR failed: {str(e)}"
611
- elif text_input.strip():
612
- text = text_input
613
- else:
614
- return "[!] Please enter some text, upload an image, or provide a Reddit URL."
615
 
616
- if text.lower().startswith("error") or "Unable to extract" in text:
617
- return f"[!] {text}"
618
 
619
- try:
620
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
621
- outputs = model(inputs)
622
- probs = tf.nn.softmax(outputs.logits, axis=1)
623
- confidence = float(tf.reduce_max(probs).numpy())
624
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
625
 
626
- if confidence < 0.5:
627
- return fallback_classifier(text)
628
 
629
- return f"Prediction: {LABELS[pred_label]}"
630
- except Exception as e:
631
- return f"[!] Prediction error: {str(e)}"
632
-
633
- def analyze_subreddit(subreddit_name):
634
- try:
635
- subreddit = reddit.subreddit(subreddit_name)
636
- posts = list(subreddit.hot(limit=20))
637
-
638
- sentiments = []
639
- titles = []
640
-
641
- for post in posts:
642
- text = f"{post.title}\n{post.selftext}"
643
- try:
644
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
645
- outputs = model(inputs)
646
- probs = tf.nn.softmax(outputs.logits, axis=1)
647
- confidence = float(tf.reduce_max(probs).numpy())
648
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
649
-
650
- sentiment = LABELS[pred_label] if confidence >= 0.5 else fallback_classifier(text).split(": ")[-1]
651
- except:
652
- sentiment = "Error"
653
- sentiments.append(sentiment)
654
- titles.append(post.title)
655
-
656
- df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
657
- sentiment_counts = df["Sentiment"].value_counts()
658
-
659
- # Plot bar chart
660
- fig, ax = plt.subplots()
661
- sentiment_counts.plot(kind="bar", color=["red", "green", "gray"], ax=ax)
662
- ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
663
- ax.set_xlabel("Sentiment")
664
- ax.set_ylabel("Number of Posts")
665
-
666
- return fig, df
667
- except Exception as e:
668
- return f"[!] Error: {str(e)}", pd.DataFrame()
669
-
670
- # Gradio tab 1: Text/Image/Reddit Post Analysis
671
- main_interface = gr.Interface(
672
- fn=classify_sentiment,
673
- inputs=[
674
- gr.Textbox(
675
- label="Text Input (can be tweet or any content)",
676
- placeholder="Paste tweet or type any content here...",
677
- lines=4
678
- ),
679
- gr.Textbox(
680
- label="Reddit Post URL",
681
- placeholder="Paste a Reddit post URL (optional)",
682
- lines=1
683
- ),
684
- gr.Image(
685
- label="Upload Image (optional)",
686
- type="pil"
687
- )
688
- ],
689
- outputs="text",
690
- title="Sentiment Analyzer",
691
- description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
692
- )
693
-
694
- # Gradio tab 2: Subreddit Analysis
695
- subreddit_interface = gr.Interface(
696
- fn=analyze_subreddit,
697
- inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
698
- outputs=[
699
- gr.Plot(label="Sentiment Distribution"),
700
- gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
701
- ],
702
- title="Subreddit Sentiment Analysis",
703
- description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts."
704
- )
705
- eval_interface = gr.Interface(
706
- fn=lambda: get_classification_report(),
707
- inputs=[],
708
- outputs="text",
709
- title="Evaluate Model",
710
- description="Run evaluation on test.csv and view classification report."
711
- )
712
-
713
-
714
- # Combine tabs
715
- '''demo = gr.TabbedInterface(
716
- interface_list=[main_interface, subreddit_interface],
717
- tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
718
- )
719
- '''
720
- demo = gr.TabbedInterface(
721
- interface_list=[main_interface, subreddit_interface, eval_interface],
722
- tab_names=["General Sentiment Analysis", "Subreddit Analysis", "Evaluate Model"]
723
- )
724
 
725
- demo.launch()
726
 
727
 
728
 
 
359
 
360
  demo.launch()
361
  '''
362
+
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
 
 
 
 
 
 
 
538
 
 
 
 
 
 
539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
 
 
541
 
 
 
 
 
 
 
542
 
 
 
543
 
544
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
+
547
 
548
 
549