shrish191 committed on
Commit
e403ca5
·
verified ·
1 Parent(s): aef26d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -175
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
- '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,188 +534,14 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
537
- '''
538
- import gradio as gr
539
- from transformers import TFBertForSequenceClassification, BertTokenizer
540
- import tensorflow as tf
541
- import praw
542
- import os
543
- import pytesseract
544
- from PIL import Image
545
- import cv2
546
- import numpy as np
547
- import re
548
-
549
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
550
- import torch
551
- from scipy.special import softmax
552
- import matplotlib.pyplot as plt
553
- import pandas as pd
554
 
555
- # Install tesseract OCR (only runs once in Hugging Face Spaces)
556
- os.system("apt-get update && apt-get install -y tesseract-ocr")
557
 
558
- # Load main model
559
- model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
560
- tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
561
 
562
- LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
563
 
564
- # Load fallback model
565
- fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
566
- fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
567
- fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
568
 
569
- # Reddit API setup
570
- reddit = praw.Reddit(
571
- client_id=os.getenv("REDDIT_CLIENT_ID"),
572
- client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
573
- user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
574
- )
575
 
576
- def fetch_reddit_text(reddit_url):
577
- try:
578
- submission = reddit.submission(url=reddit_url)
579
- return f"{submission.title}\n\n{submission.selftext}"
580
- except Exception as e:
581
- return f"Error fetching Reddit post: {str(e)}"
582
 
583
- def fallback_classifier(text):
584
- encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
585
- with torch.no_grad():
586
- output = fallback_model(**encoded_input)
587
- scores = softmax(output.logits.numpy()[0])
588
- labels = ['Negative', 'Neutral', 'Positive']
589
- pred_index = scores.argmax()
590
- pred_label = labels[pred_index]
591
- confidence = float(scores[pred_index])
592
- return f"Prediction: {pred_label} (Confidence: {confidence*100:.2f}%)"
593
-
594
- def clean_ocr_text(text):
595
- text = text.strip()
596
- text = re.sub(r'\s+', ' ', text)
597
- text = re.sub(r'[^\x00-\x7F]+', '', text)
598
- return text
599
-
600
- def classify_sentiment(text_input, reddit_url, image):
601
- if reddit_url.strip():
602
- text = fetch_reddit_text(reddit_url)
603
- elif image is not None:
604
- try:
605
- img_array = np.array(image)
606
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
607
- _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
608
- text = pytesseract.image_to_string(thresh)
609
- text = clean_ocr_text(text)
610
- except Exception as e:
611
- return f"[!] OCR failed: {str(e)}"
612
- elif text_input.strip():
613
- text = text_input
614
- else:
615
- return "[!] Please enter some text, upload an image, or provide a Reddit URL."
616
-
617
- if text.lower().startswith("error") or "Unable to extract" in text:
618
- return f"[!] {text}"
619
-
620
- try:
621
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
622
- outputs = model(inputs)
623
- probs = tf.nn.softmax(outputs.logits, axis=1)
624
- confidence = float(tf.reduce_max(probs).numpy())
625
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
626
-
627
- if confidence < 0.5:
628
- return fallback_classifier(text)
629
 
630
- return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence*100:.2f}%)"
631
- except Exception as e:
632
- return f"[!] Prediction error: {str(e)}"
633
-
634
- # Subreddit sentiment analysis function
635
- def analyze_subreddit(subreddit_name):
636
- try:
637
- subreddit = reddit.subreddit(subreddit_name)
638
- posts = list(subreddit.hot(limit=20))
639
-
640
- sentiments = []
641
- titles = []
642
-
643
- for post in posts:
644
- text = f"{post.title}\n{post.selftext}"
645
- try:
646
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
647
- outputs = model(inputs)
648
- probs = tf.nn.softmax(outputs.logits, axis=1)
649
- confidence = float(tf.reduce_max(probs).numpy())
650
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
651
-
652
- if confidence >= 0.5:
653
- sentiment = f"{LABELS[pred_label]} ({confidence*100:.2f}%)"
654
- else:
655
- sentiment = fallback_classifier(text).replace("Prediction: ", "")
656
- except:
657
- sentiment = "Error"
658
- sentiments.append(sentiment)
659
- titles.append(post.title)
660
-
661
- df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
662
- sentiment_labels_only = [s.split(" ")[0] if "(" in s else s for s in sentiments]
663
- sentiment_counts = pd.Series(sentiment_labels_only).value_counts()
664
-
665
- # Plot bar chart
666
- fig, ax = plt.subplots()
667
- sentiment_counts.plot(kind="bar", color=["red", "green", "gray"], ax=ax)
668
- ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
669
- ax.set_xlabel("Sentiment")
670
- ax.set_ylabel("Number of Posts")
671
-
672
- return fig, df
673
- except Exception as e:
674
- return f"[!] Error: {str(e)}", pd.DataFrame()
675
-
676
- # Gradio tab 1: Text/Image/Reddit Post Analysis
677
- main_interface = gr.Interface(
678
- fn=classify_sentiment,
679
- inputs=[
680
- gr.Textbox(
681
- label="Text Input (can be tweet or any content)",
682
- placeholder="Paste tweet or type any content here...",
683
- lines=4
684
- ),
685
- gr.Textbox(
686
- label="Reddit Post URL",
687
- placeholder="Paste a Reddit post URL (optional)",
688
- lines=1
689
- ),
690
- gr.Image(
691
- label="Upload Image (optional)",
692
- type="pil"
693
- )
694
- ],
695
- outputs="text",
696
- title="Sentiment Analyzer",
697
- description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
698
- )
699
-
700
- # Gradio tab 2: Subreddit Analysis
701
- subreddit_interface = gr.Interface(
702
- fn=analyze_subreddit,
703
- inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
704
- outputs=[
705
- gr.Plot(label="Sentiment Distribution"),
706
- gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
707
- ],
708
- title="Subreddit Sentiment Analysis",
709
- description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts."
710
- )
711
-
712
- # Tabs
713
- demo = gr.TabbedInterface(
714
- interface_list=[main_interface, subreddit_interface],
715
- tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
716
- )
717
-
718
- demo.launch()
719
 
720
 
721
 
 
359
 
360
  demo.launch()
361
  '''
362
+
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
 
 
538
 
 
 
 
539
 
 
540
 
 
 
 
 
541
 
 
 
 
 
 
 
542
 
 
 
 
 
 
 
543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
 
547