shrish191 committed on
Commit 3f08853 · verified · 1 Parent(s): 6c11ebf

Update app.py

Files changed (1)
  1. app.py +176 -1
app.py CHANGED
@@ -237,7 +237,7 @@ demo = gr.Interface(
 
 demo.launch()
 '''
-
+'''
 import gradio as gr
 from transformers import TFBertForSequenceClassification, BertTokenizer
 import tensorflow as tf
@@ -358,6 +358,181 @@ demo = gr.Interface(
 )
 
 demo.launch()
+'''
+import gradio as gr
+from transformers import TFBertForSequenceClassification, BertTokenizer
+import tensorflow as tf
+import praw
+import os
+import pytesseract
+from PIL import Image
+import cv2
+import numpy as np
+import re
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from scipy.special import softmax
+import matplotlib.pyplot as plt
+import pandas as pd
+
+# Install tesseract OCR (only runs once in Hugging Face Spaces)
+os.system("apt-get update && apt-get install -y tesseract-ocr")
+
+# Load main model
+model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
+tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
+
+LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
+
+# Load fallback model
+fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
+fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
+fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
+
+# Reddit API setup
+reddit = praw.Reddit(
+    client_id=os.getenv("REDDIT_CLIENT_ID"),
+    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
+    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
+)
+
+def fetch_reddit_text(reddit_url):
+    try:
+        submission = reddit.submission(url=reddit_url)
+        return f"{submission.title}\n\n{submission.selftext}"
+    except Exception as e:
+        return f"Error fetching Reddit post: {str(e)}"
+
+def fallback_classifier(text):
+    encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
+    with torch.no_grad():
+        output = fallback_model(**encoded_input)
+    scores = softmax(output.logits.numpy()[0])
+    labels = ['Negative', 'Neutral', 'Positive']
+    return f"Prediction: {labels[scores.argmax()]}"
+
+def clean_ocr_text(text):
+    text = text.strip()
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'[^\x00-\x7F]+', '', text)
+    return text
+
+def classify_sentiment(text_input, reddit_url, image):
+    if reddit_url.strip():
+        text = fetch_reddit_text(reddit_url)
+    elif image is not None:
+        try:
+            img_array = np.array(image)
+            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+            _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
+            text = pytesseract.image_to_string(thresh)
+            text = clean_ocr_text(text)
+        except Exception as e:
+            return f"[!] OCR failed: {str(e)}"
+    elif text_input.strip():
+        text = text_input
+    else:
+        return "[!] Please enter some text, upload an image, or provide a Reddit URL."
+
+    if text.lower().startswith("error") or "Unable to extract" in text:
+        return f"[!] {text}"
+
+    try:
+        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
+        outputs = model(inputs)
+        probs = tf.nn.softmax(outputs.logits, axis=1)
+        confidence = float(tf.reduce_max(probs).numpy())
+        pred_label = tf.argmax(probs, axis=1).numpy()[0]
+
+        if confidence < 0.5:
+            return fallback_classifier(text)
+
+        return f"Prediction: {LABELS[pred_label]}"
+    except Exception as e:
+        return f"[!] Prediction error: {str(e)}"
+
+# Subreddit sentiment analysis function
+def analyze_subreddit(subreddit_name):
+    try:
+        subreddit = reddit.subreddit(subreddit_name)
+        posts = list(subreddit.hot(limit=20))
+
+        sentiments = []
+        titles = []
+
+        for post in posts:
+            text = f"{post.title}\n{post.selftext}"
+            try:
+                inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
+                outputs = model(inputs)
+                probs = tf.nn.softmax(outputs.logits, axis=1)
+                confidence = float(tf.reduce_max(probs).numpy())
+                pred_label = tf.argmax(probs, axis=1).numpy()[0]
+
+                sentiment = LABELS[pred_label] if confidence >= 0.5 else fallback_classifier(text).split(": ")[-1]
+            except:
+                sentiment = "Error"
+            sentiments.append(sentiment)
+            titles.append(post.title)
+
+        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
+        sentiment_counts = df["Sentiment"].value_counts()
+
+        # Plot bar chart
+        fig, ax = plt.subplots()
+        sentiment_counts.plot(kind="bar", color=["red", "green", "gray"], ax=ax)
+        ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
+        ax.set_xlabel("Sentiment")
+        ax.set_ylabel("Number of Posts")
+
+        return fig, df
+    except Exception as e:
+        return f"[!] Error: {str(e)}", pd.DataFrame()
+
+# Gradio tab 1: Text/Image/Reddit Post Analysis
+main_interface = gr.Interface(
+    fn=classify_sentiment,
+    inputs=[
+        gr.Textbox(
+            label="Text Input (can be tweet or any content)",
+            placeholder="Paste tweet or type any content here...",
+            lines=4
+        ),
+        gr.Textbox(
+            label="Reddit Post URL",
+            placeholder="Paste a Reddit post URL (optional)",
+            lines=1
+        ),
+        gr.Image(
+            label="Upload Image (optional)",
+            type="pil"
+        )
+    ],
+    outputs="text",
+    title="Sentiment Analyzer",
+    description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
+)
+
+# Gradio tab 2: Subreddit Analysis
+subreddit_interface = gr.Interface(
+    fn=analyze_subreddit,
+    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
+    outputs=[
+        gr.Plot(label="Sentiment Distribution"),
+        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
+    ],
+    title="Subreddit Sentiment Analysis",
+    description="📊 Enter a subreddit to analyze sentiment of its top 20 hot posts."
+)
+
+# Tabs
+demo = gr.TabbedInterface(
+    interface_list=[main_interface, subreddit_interface],
+    tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
+)
+
+demo.launch()
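
A minimal launcher sketch, not part of the commit: the new code reads Reddit API credentials through os.getenv("REDDIT_CLIENT_ID"), os.getenv("REDDIT_CLIENT_SECRET"), and os.getenv("REDDIT_USER_AGENT"), so running this app.py outside an already configured Space requires those variables to be set first. Only the variable names come from the diff; the file path and placeholder values below are assumptions.

# Hypothetical local-run sketch (not part of the commit): set the Reddit
# credentials that app.py reads via os.getenv, then execute it in-process.
# All values are placeholders.
import os
import runpy

os.environ.setdefault("REDDIT_CLIENT_ID", "your-client-id")
os.environ.setdefault("REDDIT_CLIENT_SECRET", "your-client-secret")
os.environ.setdefault("REDDIT_USER_AGENT", "sentiment-classifier-ui")

runpy.run_path("app.py", run_name="__main__")  # starts the Gradio demo

On Hugging Face Spaces these values are typically supplied as repository secrets rather than set in code.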