shrish191 committed on
Commit
aef26d4
Β·
verified Β·
1 Parent(s): 8cedd8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -1
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
-
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,6 +534,188 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
 
539
 
 
359
 
360
  demo.launch()
361
  '''
362
+ '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
537
+ '''
538
+ import gradio as gr
539
+ from transformers import TFBertForSequenceClassification, BertTokenizer
540
+ import tensorflow as tf
541
+ import praw
542
+ import os
543
+ import pytesseract
544
+ from PIL import Image
545
+ import cv2
546
+ import numpy as np
547
+ import re
548
+
549
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
550
+ import torch
551
+ from scipy.special import softmax
552
+ import matplotlib.pyplot as plt
553
+ import pandas as pd
554
+
555
# --- One-time environment setup -------------------------------------------
# Install the tesseract OCR binary (pytesseract shells out to it).
# NOTE(review): this runs on every process start, not "only once" — apt-get
# is simply a fast no-op when the package is already present (HF Spaces).
os.system("apt-get update && apt-get install -y tesseract-ocr")

# Primary sentiment model: fine-tuned BERT with TensorFlow weights.
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")

# Index-to-label mapping for the primary model's three output classes.
LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}

# Fallback model (PyTorch RoBERTa) used when the primary model's
# confidence falls below 0.5 — see classify_sentiment/analyze_subreddit.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)

# Reddit API client; credentials are supplied via environment variables.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
)
575
+
576
def fetch_reddit_text(reddit_url):
    """Fetch a Reddit submission's title and body text from its URL.

    Returns the combined text on success, or an error string (starting
    with "Error fetching Reddit post:") on failure — callers detect the
    error case by inspecting the returned string.
    """
    try:
        post = reddit.submission(url=reddit_url)
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"
    return f"{post.title}\n\n{post.selftext}"
582
+
583
def fallback_classifier(text):
    """Classify *text* with the RoBERTa fallback model.

    Returns a formatted string: "Prediction: <label> (Confidence: <pct>%)".
    Runs inference without gradient tracking.
    """
    tokens = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        logits = fallback_model(**tokens).logits
    # Softmax over the single example's logits -> class probabilities.
    probs = softmax(logits.numpy()[0])
    label_names = ['Negative', 'Neutral', 'Positive']
    best = probs.argmax()
    confidence = float(probs[best])
    return f"Prediction: {label_names[best]} (Confidence: {confidence*100:.2f}%)"
593
+
594
def clean_ocr_text(text):
    """Normalize raw OCR output.

    Trims surrounding whitespace, collapses internal whitespace runs to a
    single space, then strips non-ASCII characters. Note the order: because
    non-ASCII removal happens last, a non-ASCII run between two spaces
    leaves a double space behind.
    """
    collapsed = re.sub(r'\s+', ' ', text.strip())
    return re.sub(r'[^\x00-\x7F]+', '', collapsed)
599
+
600
def classify_sentiment(text_input, reddit_url, image):
    """Classify the sentiment of text taken from one of three sources.

    Source priority: Reddit URL > uploaded image (OCR) > raw text input.

    Args:
        text_input: free-form text from the textbox (may be None/empty).
        reddit_url: URL of a Reddit post (may be None/empty).
        image: PIL image containing text, or None.

    Returns:
        A formatted prediction string, or an error string prefixed "[!]".
    """
    # Guard each .strip() with a truthiness check: Gradio can hand a
    # cleared component back as None, and None.strip() would raise.
    if reddit_url and reddit_url.strip():
        text = fetch_reddit_text(reddit_url)
    elif image is not None:
        try:
            img_array = np.array(image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            # Binarize to improve tesseract accuracy on screenshots.
            _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
            text = clean_ocr_text(pytesseract.image_to_string(thresh))
            # Don't feed an empty OCR result to the model — tell the user.
            if not text:
                return "[!] No readable text found in the image."
        except Exception as e:
            return f"[!] OCR failed: {str(e)}"
    elif text_input and text_input.strip():
        text = text_input
    else:
        return "[!] Please enter some text, upload an image, or provide a Reddit URL."

    # fetch_reddit_text signals failure via an "Error ..." string.
    if text.lower().startswith("error") or "Unable to extract" in text:
        return f"[!] {text}"

    try:
        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

        # Low-confidence predictions are delegated to the fallback model.
        if confidence < 0.5:
            return fallback_classifier(text)

        return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence*100:.2f}%)"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"
633
+
634
# Subreddit sentiment analysis function
def analyze_subreddit(subreddit_name):
    """Analyze sentiment of the top 20 "hot" posts in a subreddit.

    Args:
        subreddit_name: bare subreddit name, e.g. "AskReddit".

    Returns:
        (matplotlib Figure, DataFrame of titles+sentiments) on success, or
        (error string, empty DataFrame) on failure.
    """
    try:
        subreddit = reddit.subreddit(subreddit_name)
        posts = list(subreddit.hot(limit=20))

        sentiments = []
        titles = []

        for post in posts:
            text = f"{post.title}\n{post.selftext}"
            try:
                inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
                outputs = model(inputs)
                probs = tf.nn.softmax(outputs.logits, axis=1)
                confidence = float(tf.reduce_max(probs).numpy())
                pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

                if confidence >= 0.5:
                    sentiment = f"{LABELS[pred_label]} ({confidence*100:.2f}%)"
                else:
                    sentiment = fallback_classifier(text).replace("Prediction: ", "")
            # Was a bare `except:`, which also swallows KeyboardInterrupt/
            # SystemExit; keep the per-post best-effort behavior but narrow it.
            except Exception:
                sentiment = "Error"
            sentiments.append(sentiment)
            titles.append(post.title)

        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
        # Strip the "(xx.xx%)" suffix so identical labels aggregate together.
        sentiment_labels_only = [s.split(" ")[0] if "(" in s else s for s in sentiments]
        sentiment_counts = pd.Series(sentiment_labels_only).value_counts()

        # value_counts() orders bars by frequency, so colors must be mapped
        # by label, not by position (the old fixed color list could paint
        # "Positive" red depending on the counts).
        color_map = {"Negative": "red", "Positive": "green", "Neutral": "gray"}
        bar_colors = [color_map.get(label, "blue") for label in sentiment_counts.index]

        # Plot bar chart
        fig, ax = plt.subplots()
        sentiment_counts.plot(kind="bar", color=bar_colors, ax=ax)
        ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
        ax.set_xlabel("Sentiment")
        ax.set_ylabel("Number of Posts")

        return fig, df
    except Exception as e:
        return f"[!] Error: {str(e)}", pd.DataFrame()
675
+
676
# Gradio tab 1: Text/Image/Reddit Post Analysis.
# Three optional inputs; classify_sentiment picks one by priority
# (Reddit URL, then image, then raw text).
# NOTE(review): the description strings below are copied verbatim — the
# emoji appear mojibake-encoded in this diff view; confirm the file on
# disk stores them as UTF-8.
main_interface = gr.Interface(
    fn=classify_sentiment,
    inputs=[
        gr.Textbox(
            label="Text Input (can be tweet or any content)",
            placeholder="Paste tweet or type any content here...",
            lines=4
        ),
        gr.Textbox(
            label="Reddit Post URL",
            placeholder="Paste a Reddit post URL (optional)",
            lines=1
        ),
        gr.Image(
            label="Upload Image (optional)",
            type="pil"
        )
    ],
    outputs="text",
    title="Sentiment Analyzer",
    description="πŸ” Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\nπŸ’‘ Tweet URLs are not supported. Please paste tweet content or screenshot instead."
)
699
+
700
# Gradio tab 2: Subreddit Analysis.
# Single textbox input; analyze_subreddit returns a bar chart of the
# sentiment distribution plus a per-post table.
subreddit_interface = gr.Interface(
    fn=analyze_subreddit,
    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
    outputs=[
        gr.Plot(label="Sentiment Distribution"),
        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
    ],
    title="Subreddit Sentiment Analysis",
    description="πŸ“Š Enter a subreddit to analyze sentiment of its top 20 hot posts."
)
711
+
712
# Tabs: combine the two interfaces into one tabbed app and start the
# Gradio server (blocking call).
demo = gr.TabbedInterface(
    interface_list=[main_interface, subreddit_interface],
    tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
)

demo.launch()
719
 
720
 
721