shrish191 committed on
Commit
149d261
·
verified ·
1 Parent(s): bd6cebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -1
app.py CHANGED
@@ -237,7 +237,7 @@ demo = gr.Interface(
237
 
238
  demo.launch()
239
  '''
240
-
241
  import gradio as gr
242
  from transformers import TFBertForSequenceClassification, BertTokenizer
243
  import tensorflow as tf
@@ -358,6 +358,155 @@ demo = gr.Interface(
358
  )
359
 
360
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
 
363
 
 
237
 
238
  demo.launch()
239
  '''
240
+ '''
241
  import gradio as gr
242
  from transformers import TFBertForSequenceClassification, BertTokenizer
243
  import tensorflow as tf
 
358
  )
359
 
360
  demo.launch()
361
+ '''
362
+ import gradio as gr
363
+ from transformers import TFBertForSequenceClassification, BertTokenizer
364
+ import tensorflow as tf
365
+ import praw
366
+ import os
367
+ import pytesseract
368
+ from PIL import Image
369
+ import numpy as np
370
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
371
+ import torch
372
+ from scipy.special import softmax
373
+ import plotly.graph_objs as go
374
+
375
# Install the Tesseract OCR binary only when it is missing. Probing first
# avoids re-running apt-get on every start-up, and checking the exit status
# surfaces a failed install instead of silently ignoring it.
try:
    pytesseract.get_tesseract_version()
except pytesseract.TesseractNotFoundError:
    install_status = os.system("apt-get update && apt-get install -y tesseract-ocr")
    if install_status != 0:
        print(
            "[!] Could not install tesseract-ocr "
            f"(exit status {install_status}); image OCR will be unavailable."
        )
377
+
378
# Primary classifier: TensorFlow BERT checkpoint plus its tokenizer, both
# loaded from the same Hugging Face repository.
MAIN_MODEL_NAME = "shrish191/sentiment-bert"
tokenizer = BertTokenizer.from_pretrained(MAIN_MODEL_NAME)
model = TFBertForSequenceClassification.from_pretrained(MAIN_MODEL_NAME)

# Maps the primary model's predicted class index to a display label.
LABELS = dict(enumerate(("Neutral", "Positive", "Negative")))

# Secondary (PyTorch) classifier, consulted when the primary model's
# confidence is low.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)

# Reddit API client; credentials are read from environment variables, with a
# default user agent when REDDIT_USER_AGENT is not set.
_reddit_settings = {
    "client_id": os.getenv("REDDIT_CLIENT_ID"),
    "client_secret": os.getenv("REDDIT_CLIENT_SECRET"),
    "user_agent": os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui"),
}
reddit = praw.Reddit(**_reddit_settings)
399
+
400
def fetch_reddit_text(reddit_url):
    """Return the title and self-text of the Reddit post at *reddit_url*.

    The two parts are separated by a blank line. Failures are not raised:
    any exception is turned into a string that starts with
    "Error fetching Reddit post:".
    """
    try:
        post = reddit.submission(url=reddit_url)
        title = post.title
        body = post.selftext
    except Exception as exc:
        return f"Error fetching Reddit post: {str(exc)}"
    return f"{title}\n\n{body}"
406
+
407
def fetch_subreddit_texts(subreddit_name, limit=15):
    """Collect the text of the current "hot" posts of a subreddit.

    Args:
        subreddit_name: Subreddit to read. Surrounding whitespace, slashes
            and a leading "r/" are ignored, so "AskReddit", " r/AskReddit "
            and "/r/AskReddit/" all name the same subreddit.
        limit: Maximum number of hot submissions to request.

    Returns:
        A list with one "title selftext" string per submission that has any
        text. On failure, a single-element list whose only item starts with
        "Error fetching subreddit:".
    """
    # Normalize what users typically paste before handing the name to PRAW.
    name = subreddit_name.strip().strip("/")
    if name.lower().startswith("r/"):
        name = name[2:]

    try:
        texts = []
        for submission in reddit.subreddit(name).hot(limit=limit):
            combined = f"{submission.title} {submission.selftext}".strip()
            if combined:
                texts.append(combined)
        return texts
    except Exception as e:
        # Errors are reported in-band; the caller inspects the first item.
        return [f"Error fetching subreddit: {str(e)}"]
418
+
419
def fallback_classifier(text):
    """Classify *text* with the fallback model.

    Returns a string of the form "Prediction: LABEL", where LABEL is
    'Negative', 'Neutral' or 'Positive' (the highest-scoring class).
    """
    class_names = ('Negative', 'Neutral', 'Positive')
    batch = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = fallback_model(**batch).logits
    # Scores of the first (only) sequence in the batch.
    probabilities = softmax(logits.numpy()[0])
    best_index = probabilities.argmax()
    return f"Prediction: {class_names[best_index]}"
426
+
427
def classify_multiple_sentiments(texts):
    """Count the predicted sentiment label of every string in *texts*.

    Each text is classified with the primary model; when its top probability
    is below 0.5 the fallback classifier decides instead. Texts that fail to
    classify are skipped (best effort) and the failure is logged.

    Args:
        texts: Iterable of strings to classify.

    Returns:
        A dict with the keys "Positive", "Neutral" and "Negative" mapping to
        the number of texts assigned to each label.
    """
    import logging  # local import so the block does not depend on file-level imports

    counts = {"Positive": 0, "Neutral": 0, "Negative": 0}
    for text in texts:
        try:
            inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
            outputs = model(inputs)
            probs = tf.nn.softmax(outputs.logits, axis=1)
            confidence = float(tf.reduce_max(probs).numpy())
            pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

            if confidence < 0.5:
                # fallback_classifier returns "Prediction: LABEL"; keep the label only.
                label = fallback_classifier(text).split(":")[-1].strip()
            else:
                label = LABELS[pred_label]

            counts[label] += 1
        except Exception:
            # Skip this text but leave a trace; unlike a bare ``except`` this
            # does not swallow KeyboardInterrupt or SystemExit.
            logging.getLogger(__name__).exception(
                "Skipping a text that could not be classified"
            )
            continue
    return counts
446
+
447
def sentiment_pie_chart(counts):
    """Build a Plotly pie chart (with a 0.3 hole) from a label-to-count dict.

    The dict keys become the slice labels and the values the slice sizes.
    Returns the resulting figure.
    """
    slice_names = list(counts)
    slice_sizes = list(counts.values())
    pie = go.Pie(labels=slice_names, values=slice_sizes, hole=.3)
    figure = go.Figure(data=[pie])
    figure.update_layout(title_text="Sentiment Distribution in Subreddit")
    return figure
453
+
454
def classify_sentiment(text_input, reddit_url, image, subreddit_name):
    """Classify sentiment for whichever input the user supplied.

    Inputs are considered in this order: subreddit name, Reddit post URL,
    uploaded image (text is extracted with OCR), plain text.

    Args:
        text_input: Free text to classify.
        reddit_url: URL of a single Reddit post.
        image: Uploaded image, or None. A PIL image is expected; anything
            else is converted through a NumPy array first.
        subreddit_name: Subreddit whose hot posts should be summarized.

    Returns:
        For a subreddit: the pie-chart figure of label counts, or a message
        string when the posts could not be fetched. Otherwise a string:
        "Prediction: LABEL" on success, or a message starting with "[!]"
        when no input was given or a step failed.
    """
    # Treat a missing (None) field like an empty one so .strip() is safe.
    text_input = text_input or ""
    reddit_url = reddit_url or ""
    subreddit_name = subreddit_name or ""

    # Subreddit dashboard has priority over the single-text inputs.
    if subreddit_name.strip():
        texts = fetch_subreddit_texts(subreddit_name)
        if not texts:
            # Previously texts[0] raised IndexError here.
            return "[!] No posts with text were found in that subreddit."
        # A failed fetch is reported as one item with this exact prefix;
        # matching the prefix avoids misreading a post that merely contains
        # the word "Error".
        if len(texts) == 1 and texts[0].startswith("Error fetching subreddit:"):
            return texts[0]
        counts = classify_multiple_sentiments(texts)
        return sentiment_pie_chart(counts)

    if reddit_url.strip():
        text = fetch_reddit_text(reddit_url)
        # Only the Reddit branch can produce this in-band error, so user
        # text that happens to start with "error" is no longer rejected.
        if text.startswith("Error fetching Reddit post:"):
            return f"[!] {text}"
    elif image is not None:
        try:
            # The original called cv2, which this file never imports, so OCR
            # always failed. Grayscale conversion is done with PIL instead;
            # Tesseract binarizes the image itself.
            if not hasattr(image, "convert"):
                image = Image.fromarray(np.asarray(image))
            gray = image.convert("L")
            text = pytesseract.image_to_string(gray)
        except Exception as e:
            return f"[!] OCR failed: {str(e)}"
        if not text.strip():
            return "[!] Unable to extract text from the image."
    elif text_input.strip():
        text = text_input
    else:
        return "[!] Please enter some text, upload an image, or provide a Reddit URL."

    try:
        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = int(tf.argmax(probs, axis=1).numpy()[0])

        # Low-confidence predictions are delegated to the fallback model.
        if confidence < 0.5:
            return fallback_classifier(text)

        return f"Prediction: {LABELS[pred_label]}"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"
494
+
495
# Gradio interface
def _route_result(text_input, reddit_url, image, subreddit_name):
    """Split classify_sentiment's result across the two output components.

    classify_sentiment returns either a message string or a Plotly figure.
    A single output component cannot render both, so the string goes to the
    text box and the figure to the plot; the unused component is cleared.
    """
    result = classify_sentiment(text_input, reddit_url, image, subreddit_name)
    if isinstance(result, str):
        return result, None
    return "", result


demo = gr.Interface(
    fn=_route_result,
    inputs=[
        gr.Textbox(label="Text Input (can be tweet or any content)", placeholder="Paste tweet or type any content here...", lines=4),
        gr.Textbox(label="Reddit Post URL", placeholder="Paste a Reddit post URL (optional)", lines=1),
        gr.Image(label="Upload Image (optional)", type="pil"),
        gr.Textbox(label="Subreddit Name", placeholder="e.g. AskReddit (optional)", lines=1),
    ],
    # gr.outputs.Component is not a usable output; use concrete components.
    outputs=[
        gr.Textbox(label="Result"),
        gr.Plot(label="Subreddit Sentiment Distribution"),
    ],
    title="Sentiment Analyzer with Dashboard",
    # One code point for the magnifying-glass emoji; the former surrogate-pair
    # escape produced two lone surrogates in a Python string.
    description="\U0001F50D Paste any text, Reddit post URL, upload an image, or enter a subreddit name to analyze sentiment."
)

demo.launch()
510
 
511
 
512