shrish191 committed on
Commit
a656f58
·
verified ·
1 Parent(s): 9ea7163

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -140
app.py CHANGED
@@ -359,155 +359,71 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
 
362
  import gradio as gr
363
- from transformers import TFBertForSequenceClassification, BertTokenizer
364
- import tensorflow as tf
365
  import praw
366
- import os
367
- import pytesseract
368
- from PIL import Image
369
- import numpy as np
370
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
371
- import torch
372
- from scipy.special import softmax
373
  import plotly.graph_objs as go
 
 
 
374
 
375
- # Install tesseract OCR (only runs once in Hugging Face Spaces)
376
- os.system("apt-get update && apt-get install -y tesseract-ocr")
377
-
378
- # Load main model
379
- model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
380
- tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
381
-
382
- LABELS = {
383
- 0: "Neutral",
384
- 1: "Positive",
385
- 2: "Negative"
386
- }
387
-
388
- # Load fallback model
389
- fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
390
- fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
391
- fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
392
-
393
- # Reddit API setup
394
- reddit = praw.Reddit(
395
- client_id=os.getenv("REDDIT_CLIENT_ID"),
396
- client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
397
- user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
398
- )
399
-
400
- def fetch_reddit_text(reddit_url):
401
- try:
402
- submission = reddit.submission(url=reddit_url)
403
- return f"{submission.title}\n\n{submission.selftext}"
404
- except Exception as e:
405
- return f"Error fetching Reddit post: {str(e)}"
406
-
407
- def fetch_subreddit_texts(subreddit_name, limit=15):
408
- texts = []
409
- try:
410
- subreddit = reddit.subreddit(subreddit_name)
411
- for submission in subreddit.hot(limit=limit):
412
- combined = f"{submission.title} {submission.selftext}".strip()
413
- if combined:
414
- texts.append(combined)
415
- return texts
416
- except Exception as e:
417
- return [f"Error fetching subreddit: {str(e)}"]
418
 
419
- def fallback_classifier(text):
420
- encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
421
- with torch.no_grad():
422
- output = fallback_model(**encoded_input)
423
- scores = softmax(output.logits.numpy()[0])
424
  labels = ['Negative', 'Neutral', 'Positive']
425
- return f"Prediction: {labels[scores.argmax()]}"
426
-
427
- def classify_multiple_sentiments(texts):
428
- counts = {"Positive": 0, "Neutral": 0, "Negative": 0}
429
- for text in texts:
430
- try:
431
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
432
- outputs = model(inputs)
433
- probs = tf.nn.softmax(outputs.logits, axis=1)
434
- confidence = float(tf.reduce_max(probs).numpy())
435
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
436
-
437
- if confidence < 0.5:
438
- label = fallback_classifier(text).split(":")[-1].strip()
439
- else:
440
- label = LABELS[pred_label]
441
-
442
- counts[label] += 1
443
- except:
444
- continue
445
- return counts
446
-
447
- def sentiment_pie_chart(counts):
448
- labels = list(counts.keys())
449
- values = list(counts.values())
450
- fig = go.Figure(data=[go.Pie(labels=labels, values=values, hole=.3)])
451
- fig.update_layout(title_text="Sentiment Distribution in Subreddit")
452
- return fig
453
-
454
- def classify_sentiment(text_input, reddit_url, image, subreddit_name):
455
- # Subreddit Dashboard has priority
456
- if subreddit_name.strip():
457
- texts = fetch_subreddit_texts(subreddit_name)
458
- if "Error" in texts[0]:
459
- return texts[0]
460
- counts = classify_multiple_sentiments(texts)
461
- return sentiment_pie_chart(counts)
462
-
463
- if reddit_url.strip():
464
- text = fetch_reddit_text(reddit_url)
465
- elif image is not None:
466
- try:
467
- img_array = np.array(image)
468
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
469
- thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
470
- text = pytesseract.image_to_string(thresh)
471
- except Exception as e:
472
- return f"[!] OCR failed: {str(e)}"
473
- elif text_input.strip():
474
- text = text_input
475
- else:
476
- return "[!] Please enter some text, upload an image, or provide a Reddit URL."
477
-
478
- if text.lower().startswith("error") or "Unable to extract" in text:
479
- return f"[!] {text}"
480
-
481
- try:
482
- inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
483
- outputs = model(inputs)
484
- probs = tf.nn.softmax(outputs.logits, axis=1)
485
- confidence = float(tf.reduce_max(probs).numpy())
486
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
487
-
488
- if confidence < 0.5:
489
- return fallback_classifier(text)
490
 
491
- return f"Prediction: {LABELS[pred_label]}"
492
- except Exception as e:
493
- return f"[!] Prediction error: {str(e)}"
494
-
495
- # Gradio interface
496
- demo = gr.Interface(
497
- fn=classify_sentiment,
498
- inputs=[
499
- gr.Textbox(label="Text Input (can be tweet or any content)", placeholder="Paste tweet or type any content here...", lines=4),
500
- gr.Textbox(label="Reddit Post URL", placeholder="Paste a Reddit post URL (optional)", lines=1),
501
- gr.Image(label="Upload Image (optional)", type="pil"),
502
- gr.Textbox(label="Subreddit Name", placeholder="e.g. AskReddit (optional)", lines=1),
503
- ],
504
- outputs=gr.outputs.Component(label="Result"),
505
- title="Sentiment Analyzer with Dashboard",
506
- description="\ud83d\udd0d Paste any text, Reddit post URL, upload an image, or enter a subreddit name to analyze sentiment."
507
  )
508
 
509
- demo.launch()
510
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
 
512
 
513
 
 
359
 
360
  demo.launch()
361
  '''
362
+
363
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import praw
from tensorflow.nn import softmax
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
370
 
371
# Load model and tokenizer
# One checkpoint identifier is passed to both from_pretrained calls so the
# tokenizer vocabulary and the TensorFlow classifier come from the same source.
# Loading happens once at import time; classify_sentiment reuses these objects.
model_name = "shrish191/sentiment-bert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
def classify_sentiment(text):
    """Classify ``text`` with the module-level sentiment model.

    Args:
        text: The string to classify.

    Returns:
        A ``(sentiment, confidence)`` tuple: the predicted label name and the
        winning class probability as a percentage rounded to two decimals.
    """
    # Class-index -> label mapping for this checkpoint (0 Neutral, 1 Positive,
    # 2 Negative). The Negative/Neutral/Positive order used previously here is
    # the convention of the cardiffnlp twitter-roberta model, not of this one.
    # NOTE(review): confirm against the checkpoint's config.id2label.
    labels = ['Neutral', 'Positive', 'Negative']

    inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
    outputs = model(inputs)
    # A single string yields a batch of one; take its probability row.
    scores = softmax(outputs.logits, axis=1).numpy()[0]

    best = int(np.argmax(scores))
    sentiment = labels[best]
    confidence = round(float(scores[best]) * 100, 2)
    return sentiment, confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
# Reddit sentiment dashboard
# Credentials are read from the environment so real secrets never have to be
# written into this file. The literal placeholders stay as fallbacks, so the
# module still imports (unauthenticated) when the variables are not set.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID", "YOUR_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET", "YOUR_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT", "YOUR_USER_AGENT"),
)
391
 
392
def analyze_subreddit(subreddit_name, num_posts):
    """Classify the titles of a subreddit's hot posts and chart the result.

    Args:
        subreddit_name: Subreddit display name, without the ``r/`` prefix.
        num_posts: Maximum number of hot submissions to request.

    Returns:
        A ``(df, fig)`` tuple. ``df`` has one row per analysed post with the
        columns ``title``, ``sentiment`` and ``confidence``; ``fig`` is a
        Plotly donut chart of the percentage of posts per sentiment.
    """
    # The UI slider may deliver a float (e.g. 30.0); use a plain int as the
    # listing limit.
    limit = int(num_posts)

    posts = []
    for submission in reddit.subreddit(subreddit_name).hot(limit=limit):
        # Stickied posts are skipped rather than classified.
        if submission.stickied:
            continue
        sentiment, confidence = classify_sentiment(submission.title)
        posts.append({"title": submission.title, "sentiment": sentiment, "confidence": confidence})

    # Explicit columns keep the frame well-formed when nothing was collected;
    # without them df['sentiment'] raises KeyError on an empty result.
    df = pd.DataFrame(posts, columns=["title", "sentiment", "confidence"])
    sentiment_counts = df['sentiment'].value_counts().reindex(['Positive', 'Neutral', 'Negative'], fill_value=0)
    total = sentiment_counts.sum()
    if total:
        sentiment_percentages = (sentiment_counts / total * 100).round(2)
    else:
        # Avoid 0/0 -> NaN slices when no post was analysed.
        sentiment_percentages = sentiment_counts.astype(float)

    fig = go.Figure(data=[
        go.Pie(labels=sentiment_percentages.index, values=sentiment_percentages.values, hole=.4)
    ])
    # Report how many posts were actually analysed, which can be fewer than
    # requested once stickied posts are dropped or the listing runs short.
    fig.update_layout(title="Sentiment Distribution in r/{} ({} posts)".format(subreddit_name, len(posts)))

    return df, fig
410
+
411
# Dashboard layout: subreddit name and post-count inputs, a trigger button,
# then a table of per-post results and a pie chart of the distribution.
with gr.Blocks() as demo:
    gr.Markdown("## Reddit Subreddit Sentiment Dashboard")

    subreddit_input = gr.Textbox(
        label="Enter Subreddit (without r/)",
        placeholder="e.g., technology",
    )
    num_posts_input = gr.Slider(
        minimum=10,
        maximum=100,
        step=10,
        value=30,
        label="Number of Posts to Analyze",
    )
    analyze_button = gr.Button("Analyze")

    sentiment_table = gr.Dataframe(label="Post Sentiments")
    sentiment_chart = gr.Plot(label="Sentiment Pie Chart")

    # analyze_subreddit returns (dataframe, figure), matching the outputs order.
    analyze_button.click(
        fn=analyze_subreddit,
        inputs=[subreddit_input, num_posts_input],
        outputs=[sentiment_table, sentiment_chart],
    )

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
427
 
428
 
429