shrish191 committed
Commit 3f40319 · verified · 1 Parent(s): cbf8340

Update app.py

Files changed (1)
  1. app.py +1 -183
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
 
 demo.launch()
 '''
-'''
+
 import gradio as gr
 from transformers import TFBertForSequenceClassification, BertTokenizer
 import tensorflow as tf
@@ -534,198 +534,16 @@ demo = gr.TabbedInterface(
 )
 
 demo.launch()
-'''
-
-import gradio as gr
-from transformers import TFBertForSequenceClassification, BertTokenizer
-import tensorflow as tf
-import praw
-import os
-import pytesseract
-from PIL import Image
-import cv2
-import numpy as np
-import re
-
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import torch
-from scipy.special import softmax
-import matplotlib.pyplot as plt
-import pandas as pd
-
-from evaluate import get_classification_report
-
-
-# Install tesseract OCR (only runs once in Hugging Face Spaces)
-os.system("apt-get update && apt-get install -y tesseract-ocr")
-
-# Load main model
-model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
-tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
-
-LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
-
-# Load fallback model
-fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
-fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
-fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
-
-# Reddit API setup
-reddit = praw.Reddit(
-    client_id=os.getenv("REDDIT_CLIENT_ID"),
-    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
-    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
-)
-
-def fetch_reddit_text(reddit_url):
-    try:
-        submission = reddit.submission(url=reddit_url)
-        return f"{submission.title}\n\n{submission.selftext}"
-    except Exception as e:
-        return f"Error fetching Reddit post: {str(e)}"
-
-def fallback_classifier(text):
-    encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
-    with torch.no_grad():
-        output = fallback_model(**encoded_input)
-    scores = softmax(output.logits.numpy()[0])
-    labels = ['Negative', 'Neutral', 'Positive']
-    return f"Prediction: {labels[scores.argmax()]}"
-
-def clean_ocr_text(text):
-    text = text.strip()
-    text = re.sub(r'\s+', ' ', text)
-    text = re.sub(r'[^\x00-\x7F]+', '', text)
-    return text
-
-def classify_sentiment(text_input, reddit_url, image):
-    if reddit_url.strip():
-        text = fetch_reddit_text(reddit_url)
-    elif image is not None:
-        try:
-            img_array = np.array(image)
-            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
-            _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
-            text = pytesseract.image_to_string(thresh)
-            text = clean_ocr_text(text)
-        except Exception as e:
-            return f"[!] OCR failed: {str(e)}"
-    elif text_input.strip():
-        text = text_input
-    else:
-        return "[!] Please enter some text, upload an image, or provide a Reddit URL."
-
-    if text.lower().startswith("error") or "Unable to extract" in text:
-        return f"[!] {text}"
-
-    try:
-        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
-        outputs = model(inputs)
-        probs = tf.nn.softmax(outputs.logits, axis=1)
-        confidence = float(tf.reduce_max(probs).numpy())
-        pred_label = tf.argmax(probs, axis=1).numpy()[0]
 
-        if confidence < 0.5:
-            return fallback_classifier(text)
 
-        return f"Prediction: {LABELS[pred_label]}"
-    except Exception as e:
-        return f"[!] Prediction error: {str(e)}"
 
-# Subreddit sentiment analysis function
-def analyze_subreddit(subreddit_name):
-    try:
-        subreddit = reddit.subreddit(subreddit_name)
-        posts = list(subreddit.hot(limit=20))
 
-        sentiments = []
-        titles = []
 
-        for post in posts:
-            text = f"{post.title}\n{post.selftext}"
-            try:
-                inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
-                outputs = model(inputs)
-                probs = tf.nn.softmax(outputs.logits, axis=1)
-                confidence = float(tf.reduce_max(probs).numpy())
-                pred_label = tf.argmax(probs, axis=1).numpy()[0]
 
-                sentiment = LABELS[pred_label] if confidence >= 0.5 else fallback_classifier(text).split(": ")[-1]
-            except:
-                sentiment = "Error"
-            sentiments.append(sentiment)
-            titles.append(post.title)
 
-        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
-        sentiment_counts = df["Sentiment"].value_counts()
 
-        # Plot bar chart
-        fig, ax = plt.subplots()
-        sentiment_counts.plot(kind="bar", color=["red", "green", "gray"], ax=ax)
-        ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
-        ax.set_xlabel("Sentiment")
-        ax.set_ylabel("Number of Posts")
 
-        return fig, df
-    except Exception as e:
-        return f"[!] Error: {str(e)}", pd.DataFrame()
 
-# Gradio tab 1: Text/Image/Reddit Post Analysis
-main_interface = gr.Interface(
-    fn=classify_sentiment,
-    inputs=[
-        gr.Textbox(
-            label="Text Input (can be tweet or any content)",
-            placeholder="Paste tweet or type any content here...",
-            lines=4
-        ),
-        gr.Textbox(
-            label="Reddit Post URL",
-            placeholder="Paste a Reddit post URL (optional)",
-            lines=1
-        ),
-        gr.Image(
-            label="Upload Image (optional)",
-            type="pil"
-        )
-    ],
-    outputs="text",
-    title="Sentiment Analyzer",
-    description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
-)
-
-# Gradio tab 2: Subreddit Analysis
-subreddit_interface = gr.Interface(
-    fn=analyze_subreddit,
-    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
-    outputs=[
-        gr.Plot(label="Sentiment Distribution"),
-        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
-    ],
-    title="Subreddit Sentiment Analysis",
-    description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts."
-)
-eval_interface = gr.Interface(
-    fn=get_classification_report,
-    inputs=[],
-    outputs="text",
-    title="Evaluate Model",
-    description="Run evaluation on test.csv and view classification report."
-)
-
-
-'''# Tabs
-demo = gr.TabbedInterface(
-    interface_list=[main_interface, subreddit_interface],
-    tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
-)'''
-demo = gr.TabbedInterface(
-    interface_list=[main_interface, subreddit_interface, eval_interface],
-    tab_names=["General Sentiment Analysis", "Subreddit Analysis", "Evaluate Model"]
-)
-
-
-demo.launch()
 
 
 
 
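The block removed above boils down to one reusable pattern: a confidence-gated two-model classifier, in which the fine-tuned TF BERT model ("shrish191/sentiment-bert") answers first and the "cardiffnlp/twitter-roberta-base-sentiment" checkpoint is consulted only when the top softmax probability drops below 0.5. The sketch below restates that pattern on its own, outside the Gradio and Reddit plumbing, so it can be exercised in isolation. It is a minimal illustration rather than the Space's actual app.py; it assumes the transformers, tensorflow, torch, and scipy packages are installed, the helper name predict_sentiment is ours, and the model IDs, label maps, and 0.5 threshold are taken from the diff.

```python
# Minimal sketch of the confidence-gated fallback used in the removed block.
# Assumptions: transformers, tensorflow, torch, and scipy are installed; the
# model IDs and the 0.5 threshold come from the diff above.
import tensorflow as tf
import torch
from scipy.special import softmax
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertTokenizer,
    TFBertForSequenceClassification,
)

# Primary model (TensorFlow): 0 = Neutral, 1 = Positive, 2 = Negative
LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")

# Fallback model (PyTorch): note the different label order of this checkpoint
FALLBACK_LABELS = ["Negative", "Neutral", "Positive"]
fallback_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_name)


def predict_sentiment(text: str, threshold: float = 0.5) -> str:
    """Classify with the primary model; defer to the fallback on low confidence."""
    inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
    probs = tf.nn.softmax(model(inputs).logits, axis=1)
    confidence = float(tf.reduce_max(probs).numpy())
    if confidence >= threshold:
        return LABELS[int(tf.argmax(probs, axis=1).numpy()[0])]
    # Low confidence: score the same text with the RoBERTa fallback instead.
    encoded = fallback_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        scores = softmax(fallback_model(**encoded).logits.numpy()[0])
    return FALLBACK_LABELS[int(scores.argmax())]


if __name__ == "__main__":
    print(predict_sentiment("This app works surprisingly well!"))
```

One detail worth noting when reading the removed code: the two checkpoints use different label orders (the fine-tuned model maps 0/1/2 to Neutral/Positive/Negative, while the Cardiff NLP checkpoint orders its outputs Negative/Neutral/Positive), which is why the original kept a separate label list inside fallback_classifier.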