shrish191 committed on
Commit
9634b65
·
verified ·
1 Parent(s): d4c903e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -2
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
-
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,9 +534,182 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
-
540
 
541
 
542
 
 
359
 
360
  demo.launch()
361
  '''
362
+ '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
537
+ '''
538
+
539
+ import gradio as gr
540
+ from transformers import TFBertForSequenceClassification, BertTokenizer
541
+ import tensorflow as tf
542
+ import praw
543
+ import os
544
+ import pytesseract
545
+ from PIL import Image
546
+ import cv2
547
+ import numpy as np
548
+ import re
549
 
550
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
551
+ import torch
552
+ from scipy.special import softmax
553
+ import matplotlib.pyplot as plt
554
+ import pandas as pd
555
+
556
# Make sure the Tesseract binary that pytesseract shells out to is available.
# This code runs on every start-up of the app, so the (slow) apt-get call is
# skipped when Tesseract can already be found, and a failed install is
# reported instead of being silently ignored.
# NOTE(review): apt-get needs root; if this app runs on Hugging Face Spaces,
# listing `tesseract-ocr` in packages.txt is presumably the supported way to
# install it -- confirm against the deployment setup.
try:
    pytesseract.get_tesseract_version()
except pytesseract.TesseractNotFoundError:
    install_status = os.system("apt-get update && apt-get install -y tesseract-ocr")
    if install_status != 0:
        print(
            "[!] Could not install tesseract-ocr "
            f"(exit status {install_status}); image OCR will not work."
        )
558
+
559
# Primary classifier: a BERT sequence-classification checkpoint loaded through
# the TensorFlow model class, together with its tokenizer.
_MAIN_MODEL_NAME = "shrish191/sentiment-bert"
model = TFBertForSequenceClassification.from_pretrained(_MAIN_MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(_MAIN_MODEL_NAME)

# Maps the primary model's predicted class index to a display label.
LABELS = {
    0: "Neutral",
    1: "Positive",
    2: "Negative",
}

# Secondary (PyTorch) classifier, used by fallback_classifier() when the
# primary model's confidence is low.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(
    fallback_model_name
)
569
+
570
# Reddit client (PRAW). Credentials are read from environment variables; only
# the user agent has a built-in default.
_reddit_settings = {
    "client_id": os.getenv("REDDIT_CLIENT_ID"),
    "client_secret": os.getenv("REDDIT_CLIENT_SECRET"),
    "user_agent": os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui"),
}
reddit = praw.Reddit(**_reddit_settings)
576
+
577
def fetch_reddit_text(reddit_url):
    """Return the title and self-text of the Reddit post at *reddit_url*.

    Args:
        reddit_url: URL of a Reddit submission. Surrounding whitespace
            (e.g. from a pasted value) is ignored.

    Returns:
        The post title followed by a blank line and the post body. When the
        post has no body, only the title is returned. On any failure the
        function does not raise; it returns a message that starts with
        "Error fetching Reddit post:" so callers can detect it.
    """
    try:
        submission = reddit.submission(url=reddit_url.strip())
        # Strip so a post with an empty selftext does not end in blank lines.
        return f"{submission.title}\n\n{submission.selftext}".strip()
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"
583
+
584
def fallback_classifier(text):
    """Classify *text* with the secondary (fallback) sentiment model.

    Args:
        text: The text to classify.

    Returns:
        A string of the form "Prediction: <label>", where <label> is one of
        "Negative", "Neutral" or "Positive".
    """
    labels = ('Negative', 'Neutral', 'Positive')
    # An explicit max_length makes truncation effective even if the
    # tokenizer's configuration does not declare a maximum length.
    # NOTE(review): 512 assumes the usual RoBERTa-base input limit -- confirm
    # against the checkpoint's config.
    encoded_input = fallback_tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        output = fallback_model(**encoded_input)
    # Softmax is monotonic, so the arg-max of the raw logits is the same class
    # the arg-max of the probabilities would select.
    best_index = int(output.logits[0].argmax())
    return f"Prediction: {labels[best_index]}"
591
+
592
# Typographic quotes and dashes mapped to their plain ASCII equivalents so
# they survive the non-ASCII filter in clean_ocr_text() instead of being
# deleted (e.g. "don\u2019t" would otherwise become "dont").
_OCR_PUNCTUATION = str.maketrans({
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
    "\u2013": "-",
    "\u2014": "-",
})


def clean_ocr_text(text):
    """Normalise raw OCR output into a single clean ASCII line.

    Steps: map typographic quotes/dashes to ASCII, collapse every run of
    whitespace (including newlines) to one space, delete the remaining
    non-ASCII characters, then collapse and trim the spaces that deletion may
    have left behind.

    Args:
        text: Raw text as returned by the OCR engine.

    Returns:
        The cleaned text, with no leading/trailing whitespace and no
        consecutive spaces. May be an empty string.
    """
    text = text.translate(_OCR_PUNCTUATION)
    # Collapse whitespace first so non-ASCII whitespace (e.g. a no-break
    # space) still acts as a word separator rather than being deleted.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\x00-\x7F]+', '', text)
    # Removing a stand-alone non-ASCII token leaves two adjacent spaces.
    text = re.sub(r' {2,}', ' ', text)
    return text.strip()
597
+
598
def _extract_image_text(image):
    """Run OCR on a PIL image and return the cleaned-up text."""
    # Convert to 3-channel RGB so the colour conversion below also works for
    # grayscale or palette images.
    img_array = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    # Fixed-threshold binarisation before handing the image to Tesseract.
    _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
    return clean_ocr_text(pytesseract.image_to_string(thresh))


def classify_sentiment(text_input, reddit_url, image):
    """Classify the sentiment of a Reddit post, an image's text, or raw text.

    Exactly one source is used, chosen in this order of precedence: the
    Reddit URL (if non-blank), then the image (if given), then the text box.

    Args:
        text_input: Free text typed by the user; may be empty or None.
        reddit_url: URL of a Reddit post; may be empty or None.
        image: A PIL image containing text, or None.

    Returns:
        "Prediction: <label>" on success, or a message starting with "[!]"
        when no input was given, the Reddit fetch or OCR failed, the image
        contained no readable text, or the prediction itself raised.
    """
    reddit_url = (reddit_url or "").strip()
    text_input = (text_input or "").strip()

    if reddit_url:
        text = fetch_reddit_text(reddit_url)
        # fetch_reddit_text() reports failures in-band with this prefix. The
        # check is limited to this branch so ordinary user text that happens
        # to begin with "Error" is still classified.
        if text.startswith("Error fetching Reddit post:"):
            return f"[!] {text}"
    elif image is not None:
        try:
            text = _extract_image_text(image)
        except Exception as e:
            return f"[!] OCR failed: {str(e)}"
        if not text:
            return "[!] No readable text was found in the image."
    elif text_input:
        text = text_input
    else:
        return "[!] Please enter some text, upload an image, or provide a Reddit URL."

    try:
        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

        # Defer to the secondary model when the primary one is unsure.
        if confidence < 0.5:
            return fallback_classifier(text)

        return f"Prediction: {LABELS[pred_label]}"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"
631
+
632
# Bar colour per sentiment label, so a label keeps its colour no matter where
# it lands in the frequency-sorted counts.
_SENTIMENT_COLORS = {"Negative": "red", "Positive": "green", "Neutral": "gray"}


def analyze_subreddit(subreddit_name):
    """Classify the hot posts of a subreddit and chart the label distribution.

    Args:
        subreddit_name: Subreddit to analyse. A leading "r/" or "/r/" and
            surrounding whitespace are ignored.

    Returns:
        A ``(figure, dataframe)`` pair. On success, ``figure`` is a matplotlib
        bar chart of label counts and ``dataframe`` has the columns "Title"
        and "Sentiment" (one row per post; "Error" marks a post that could not
        be classified). On failure, ``figure`` is None and ``dataframe`` has a
        single "Error" column holding the message.
    """
    name = (subreddit_name or "").strip()
    name = re.sub(r"^/?r/", "", name, flags=re.IGNORECASE).strip("/")
    if not name:
        return None, pd.DataFrame({"Error": ["[!] Please enter a subreddit name."]})

    try:
        posts = list(reddit.subreddit(name).hot(limit=20))

        sentiments = []
        titles = []

        for post in posts:
            text = f"{post.title}\n{post.selftext}"
            try:
                inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
                outputs = model(inputs)
                probs = tf.nn.softmax(outputs.logits, axis=1)
                confidence = float(tf.reduce_max(probs).numpy())
                pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

                if confidence >= 0.5:
                    sentiment = LABELS[pred_label]
                else:
                    # fallback_classifier() returns "Prediction: <label>".
                    sentiment = fallback_classifier(text).split(": ")[-1]
            except Exception:
                # One unclassifiable post should not abort the whole report.
                sentiment = "Error"
            sentiments.append(sentiment)
            titles.append(post.title)

        if not titles:
            return None, pd.DataFrame({"Error": [f"[!] No posts found in r/{name}."]})

        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
        sentiment_counts = df["Sentiment"].value_counts()
        bar_colors = [
            _SENTIMENT_COLORS.get(label, "black") for label in sentiment_counts.index
        ]

        # Plot bar chart
        fig, ax = plt.subplots()
        sentiment_counts.plot(kind="bar", color=bar_colors, ax=ax)
        ax.set_title(f"Sentiment Distribution in r/{name}")
        ax.set_xlabel("Sentiment")
        ax.set_ylabel("Number of Posts")
        # Unregister the figure from pyplot so figures do not pile up across
        # calls; the Figure object itself is still returned to the caller.
        plt.close(fig)

        return fig, df
    except Exception as e:
        return None, pd.DataFrame({"Error": [f"[!] Error: {str(e)}"]})
668
+
669
# Gradio tab 1: Text/Image/Reddit Post Analysis
# The three inputs are passed to classify_sentiment() in the order listed.
main_interface = gr.Interface(
    fn=classify_sentiment,
    inputs=[
        gr.Textbox(
            label="Text Input (can be tweet or any content)",
            placeholder="Paste tweet or type any content here...",
            lines=4
        ),
        gr.Textbox(
            label="Reddit Post URL",
            placeholder="Paste a Reddit post URL (optional)",
            lines=1
        ),
        gr.Image(
            label="Upload Image (optional)",
            type="pil"
        )
    ],
    outputs="text",
    title="Sentiment Analyzer",
    # The emoji below replace mis-encoded byte sequences in the description.
    description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
)
692
+
693
# Gradio tab 2: Subreddit Analysis
# analyze_subreddit() returns a (plot, dataframe) pair matching the two
# outputs declared here.
subreddit_interface = gr.Interface(
    fn=analyze_subreddit,
    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
    outputs=[
        gr.Plot(label="Sentiment Distribution"),
        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
    ],
    title="Subreddit Sentiment Analysis",
    # The emoji below replaces a mis-encoded byte sequence in the description.
    description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts."
)
704
+
705
# Put both interfaces behind one tabbed UI and start serving it.
_tabs = [main_interface, subreddit_interface]
_tab_titles = ["General Sentiment Analysis", "Subreddit Analysis"]
demo = gr.TabbedInterface(_tabs, _tab_titles)

demo.launch()
712
 
 
713
 
714
 
715