shrish191 committed on
Commit
ae4aef6
·
verified ·
1 Parent(s): 61313a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -1
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
-
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,9 +534,198 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
 
538
 
 
 
 
 
539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
541
 
542
 
 
359
 
360
  demo.launch()
361
  '''
362
+ '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
537
+ '''
538
+
539
+ import gradio as gr
540
+ from transformers import TFBertForSequenceClassification, BertTokenizer
541
+ import tensorflow as tf
542
+ import praw
543
+ import os
544
+ import pytesseract
545
+ from PIL import Image
546
+ import cv2
547
+ import numpy as np
548
+ import re
549
+
550
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
551
+ import torch
552
+ from scipy.special import softmax
553
+ import matplotlib.pyplot as plt
554
+ import pandas as pd
555
+
556
+ from evaluate import get_classification_report
557
+
558
+
559
# Install the Tesseract OCR system package used for image text extraction.
# NOTE: this executes every time the module is run, not just once.
_tesseract_install_status = os.system(
    "apt-get update && apt-get install -y tesseract-ocr"
)
if _tesseract_install_status != 0:
    # os.system() reports failure only through its return value; surface it
    # so that later OCR errors can be traced back to the missing binary.
    print(
        "[!] Tesseract installation failed "
        f"(exit status {_tesseract_install_status}); image OCR may not work."
    )
561
+
562
# Primary sentiment model (TensorFlow BERT) and its tokenizer.
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")

# Maps the primary model's predicted class index to a display label.
LABELS = {
    0: "Neutral",
    1: "Positive",
    2: "Negative",
}

# Fallback model (PyTorch), used when the primary model's confidence is low.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)

# Reddit API client; credentials are read from environment variables.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui"),
)
579
+
580
def fetch_reddit_text(reddit_url):
    """Return the title and self-text of the Reddit post at ``reddit_url``.

    Args:
        reddit_url: URL of a Reddit submission. Surrounding whitespace
            (e.g. from a pasted URL) is ignored.

    Returns:
        str: ``"<title>\\n\\n<selftext>"`` on success. On any failure the
        exception is not raised; instead a string starting with
        ``"Error fetching Reddit post: "`` is returned.
    """
    try:
        # Strip inside the try so a non-string value is reported the same
        # way as any other fetch failure.
        submission = reddit.submission(url=reddit_url.strip())
        return f"{submission.title}\n\n{submission.selftext}"
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"
586
+
587
def fallback_classifier(text):
    """Classify ``text`` with the fallback model.

    Args:
        text: The text to classify.

    Returns:
        str: ``"Prediction: <label>"`` where ``<label>`` is one of
        ``"Negative"``, ``"Neutral"`` or ``"Positive"``.
    """
    encoded_input = fallback_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        # Explicit cap: truncation=True alone relies on the tokenizer's
        # configured maximum length, which is not visible in this file.
        # NOTE(review): 512 assumes a RoBERTa-base position limit - confirm.
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        output = fallback_model(**encoded_input)
    # Logits of the single input, converted to probabilities.
    scores = softmax(output.logits.numpy()[0])
    # Index i of this list is the label for class i of the fallback model.
    labels = ["Negative", "Neutral", "Positive"]
    return f"Prediction: {labels[scores.argmax()]}"
594
+
595
def clean_ocr_text(text):
    """Normalise OCR output to single-spaced ASCII text.

    Args:
        text: Raw text as produced by the OCR step.

    Returns:
        str: ``text`` with every whitespace run collapsed to one space, all
        non-ASCII characters removed, and no leading or trailing whitespace.
    """
    # Collapse whitespace first so Unicode spaces (e.g. NBSP) become word
    # separators instead of being deleted as non-ASCII characters.
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^\x00-\x7F]+", "", text)
    # Dropping a non-ASCII run can leave doubled or edge spaces behind,
    # so normalise once more before returning.
    return re.sub(r"\s+", " ", text).strip()
600
+
601
def classify_sentiment(text_input, reddit_url, image):
    """Classify the sentiment of a Reddit post, the text in an image, or raw text.

    Exactly one source is used, in this priority order: ``reddit_url`` (when
    non-blank), then ``image`` (when provided), then ``text_input`` (when
    non-blank).

    Args:
        text_input: Free text to classify; may be empty or ``None``.
        reddit_url: URL of a Reddit post; may be empty or ``None``.
        image: Image containing text (anything ``np.array`` accepts), or ``None``.

    Returns:
        str: ``"Prediction: <label>"`` on success, otherwise a message
        prefixed with ``"[!] "`` describing what went wrong.
    """
    # Treat a missing (None) textbox value like an empty one.
    text_input = text_input or ""
    reddit_url = reddit_url or ""

    if reddit_url.strip():
        text = fetch_reddit_text(reddit_url.strip())
        # fetch_reddit_text reports failures as a string starting with
        # "Error". Check only on this branch so that user or OCR text which
        # happens to start with "Error" is still classified.
        if text.lower().startswith("error") or "Unable to extract" in text:
            return f"[!] {text}"
    elif image is not None:
        try:
            img_array = np.array(image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            # Binarise before OCR.
            _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
            text = pytesseract.image_to_string(thresh)
            text = clean_ocr_text(text)
        except Exception as e:
            return f"[!] OCR failed: {str(e)}"
        if not text.strip():
            # Nothing to classify; do not feed an empty string to the model.
            return "[!] No readable text was found in the image."
    elif text_input.strip():
        text = text_input
    else:
        return "[!] Please enter some text, upload an image, or provide a Reddit URL."

    try:
        inputs = tokenizer(
            text,
            return_tensors="tf",
            truncation=True,
            # NOTE(review): explicit cap assumes a 512-token BERT limit - confirm.
            max_length=512,
            padding=True,
        )
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = tf.argmax(probs, axis=1).numpy()[0]

        # Defer to the fallback model when the primary model is unsure.
        if confidence < 0.5:
            return fallback_classifier(text)

        return f"Prediction: {LABELS[pred_label]}"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"
634
+
635
# Bar colour per sentiment label in the subreddit chart; labels not listed
# here (e.g. "Error") are drawn in black.
_SENTIMENT_COLORS = {
    "Negative": "red",
    "Positive": "green",
    "Neutral": "gray",
}


def _label_post(text):
    """Return the sentiment label for ``text``, or ``"Error"`` if prediction fails."""
    try:
        inputs = tokenizer(
            text,
            return_tensors="tf",
            truncation=True,
            # NOTE(review): explicit cap assumes a 512-token BERT limit - confirm.
            max_length=512,
            padding=True,
        )
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = tf.argmax(probs, axis=1).numpy()[0]
        if confidence >= 0.5:
            return LABELS[pred_label]
        # fallback_classifier returns "Prediction: <label>"; keep the label only.
        return fallback_classifier(text).split(": ")[-1]
    except Exception:
        # Best effort: one failing post must not abort the whole analysis.
        return "Error"


# Subreddit sentiment analysis function
def analyze_subreddit(subreddit_name):
    """Classify the first 20 hot posts of a subreddit and chart the result.

    Args:
        subreddit_name: Subreddit name, with or without a leading ``r/`` or
            ``/r/``; surrounding whitespace is ignored.

    Returns:
        tuple: ``(fig, df)`` where ``fig`` is a matplotlib figure with one bar
        per sentiment label and ``df`` has the columns ``"Title"`` and
        ``"Sentiment"``. On failure, ``("[!] Error: <reason>", <empty DataFrame>)``.
        NOTE(review): the failure value puts a string in the slot that the UI
        binds to a plot component - confirm it is displayed acceptably.
    """
    name = (subreddit_name or "").strip()
    for prefix in ("/r/", "r/"):
        if name.lower().startswith(prefix):
            name = name[len(prefix):]
            break
    if not name:
        return "[!] Error: Please enter a subreddit name.", pd.DataFrame()

    try:
        subreddit = reddit.subreddit(name)
        posts = list(subreddit.hot(limit=20))

        titles = [post.title for post in posts]
        sentiments = [_label_post(f"{post.title}\n{post.selftext}") for post in posts]

        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
        sentiment_counts = df["Sentiment"].value_counts()

        # value_counts() orders by frequency, so pick each bar's colour from
        # its label rather than from its position.
        bar_colors = [
            _SENTIMENT_COLORS.get(label, "black") for label in sentiment_counts.index
        ]

        # Plot bar chart
        fig, ax = plt.subplots()
        sentiment_counts.plot(kind="bar", color=bar_colors, ax=ax)
        ax.set_title(f"Sentiment Distribution in r/{name}")
        ax.set_xlabel("Sentiment")
        ax.set_ylabel("Number of Posts")

        return fig, df
    except Exception as e:
        return f"[!] Error: {str(e)}", pd.DataFrame()
672
+
673
# Gradio tab 1: Text/Image/Reddit Post Analysis
main_interface = gr.Interface(
    fn=classify_sentiment,
    # Listed in the same order as classify_sentiment's parameters:
    # text_input, reddit_url, image.
    inputs=[
        gr.Textbox(
            label="Text Input (can be tweet or any content)",
            placeholder="Paste tweet or type any content here...",
            lines=4,
        ),
        gr.Textbox(
            label="Reddit Post URL",
            placeholder="Paste a Reddit post URL (optional)",
            lines=1,
        ),
        gr.Image(
            label="Upload Image (optional)",
            type="pil",
        ),
    ],
    outputs="text",
    title="Sentiment Analyzer",
    description=(
        "🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n"
        "💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
    ),
)
696
+
697
# Gradio tab 2: Subreddit Analysis
subreddit_interface = gr.Interface(
    fn=analyze_subreddit,
    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
    # Listed in the same order as analyze_subreddit's return value: (figure, dataframe).
    outputs=[
        gr.Plot(label="Sentiment Distribution"),
        gr.Dataframe(label="Post Titles and Sentiments", wrap=True),
    ],
    title="Subreddit Sentiment Analysis",
    description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts.",
)
708
# Gradio tab 3: Model Evaluation (takes no inputs; shows the report as text)
eval_interface = gr.Interface(
    title="Evaluate Model",
    description="Run evaluation on test.csv and view classification report.",
    fn=get_classification_report,
    inputs=[],
    outputs="text",
)
715
+
716
+
717
# Tabs: interface_list and tab_names are paired by position.
demo = gr.TabbedInterface(
    interface_list=[main_interface, subreddit_interface, eval_interface],
    tab_names=["General Sentiment Analysis", "Subreddit Analysis", "Evaluate Model"],
)

demo.launch()
729
 
730
 
731