shrish191 committed on
Commit
1dfe639
·
verified ·
1 Parent(s): a656f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -52
app.py CHANGED
@@ -361,69 +361,134 @@ demo.launch()
361
  '''
362
 
363
  import gradio as gr
364
- import praw
365
- import pandas as pd
366
- import plotly.graph_objs as go
367
- from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
368
- from tensorflow.nn import softmax
369
  import numpy as np
 
 
 
 
 
 
 
370
 
371
- # Load model and tokenizer
372
- model_name = "shrish191/sentiment-bert"
373
- tokenizer = AutoTokenizer.from_pretrained(model_name)
374
- model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
375
 
376
- def classify_sentiment(text):
377
- inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
378
- outputs = model(inputs)
379
- scores = softmax(outputs.logits, axis=1).numpy()[0]
380
- labels = ['Negative', 'Neutral', 'Positive']
381
- sentiment = labels[np.argmax(scores)]
382
- confidence = round(float(np.max(scores)) * 100, 2)
383
- return sentiment, confidence
384
 
385
- # Reddit sentiment dashboard
386
- reddit = praw.Reddit(
387
- client_id="YOUR_CLIENT_ID",
388
- client_secret="YOUR_CLIENT_SECRET",
389
- user_agent="YOUR_USER_AGENT"
390
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
- def analyze_subreddit(subreddit_name, num_posts):
393
- posts = []
394
- for submission in reddit.subreddit(subreddit_name).hot(limit=num_posts):
395
- if not submission.stickied:
396
- sentiment, confidence = classify_sentiment(submission.title)
397
- posts.append({"title": submission.title, "sentiment": sentiment, "confidence": confidence})
398
 
399
- df = pd.DataFrame(posts)
400
- sentiment_counts = df['sentiment'].value_counts().reindex(['Positive', 'Neutral', 'Negative'], fill_value=0)
401
- total = sentiment_counts.sum()
402
- sentiment_percentages = (sentiment_counts / total * 100).round(2)
 
403
 
404
- fig = go.Figure(data=[
405
- go.Pie(labels=sentiment_percentages.index, values=sentiment_percentages.values, hole=.4)
406
- ])
407
- fig.update_layout(title="Sentiment Distribution in r/{} ({} posts)".format(subreddit_name, num_posts))
408
 
409
- return df, fig
 
 
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  with gr.Blocks() as demo:
412
- gr.Markdown("## Reddit Subreddit Sentiment Dashboard")
413
- subreddit_input = gr.Textbox(label="Enter Subreddit (without r/)", placeholder="e.g., technology")
414
- num_posts_input = gr.Slider(10, 100, step=10, value=30, label="Number of Posts to Analyze")
415
- analyze_button = gr.Button("Analyze")
416
- sentiment_table = gr.Dataframe(label="Post Sentiments")
417
- sentiment_chart = gr.Plot(label="Sentiment Pie Chart")
418
-
419
- analyze_button.click(
420
- analyze_subreddit,
421
- inputs=[subreddit_input, num_posts_input],
422
- outputs=[sentiment_table, sentiment_chart]
423
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
- if __name__ == "__main__":
426
- demo.launch()
427
 
428
 
429
 
 
361
  '''
362
 
363
  import gradio as gr
364
+ from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
365
+ import tensorflow as tf
 
 
 
366
  import numpy as np
367
+ import praw
368
+ import re
369
+ from wordcloud import WordCloud
370
+ import matplotlib.pyplot as plt
371
+ from collections import Counter
372
+ import plotly.graph_objects as go
373
+ import os
374
 
375
# Load the pre-trained tokenizer and model once, at import time; the
# prediction function below reads them as module-level globals.
_MODEL_ID = "shrish191/sentiment-bert"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = TFAutoModelForSequenceClassification.from_pretrained(_MODEL_ID)

# Predicted class index -> label returned to the user.
label_map = {
    0: 'Negative',
    1: 'Neutral',
    2: 'Positive',
}
 
 
 
 
 
 
 
380
 
381
+ # Sentiment Prediction Function
382
def predict_sentiment(text):
    """Classify ``text`` and return its sentiment label.

    The text is tokenized into TensorFlow tensors (with truncation and
    padding enabled), passed through the module-level ``model``, and the
    class with the highest softmax probability is looked up in ``label_map``.

    Args:
        text: The string to classify.

    Returns:
        The ``label_map`` entry for the highest-probability class index.
    """
    encoded = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
    # First element of the model output; presumably the logits tensor.
    raw_scores = model(encoded)[0]
    class_probs = tf.nn.softmax(raw_scores, axis=1).numpy()
    # A single text is passed in, so only row 0 of the batch is read.
    best_index = np.argmax(class_probs, axis=1)[0]
    return label_map[best_index]
388
+
389
+ # Reddit URL Handling
390
def analyze_reddit_url(url):
    """Summarize the sentiment of the comments on one Reddit post.

    Up to 100 comments whose body is longer than 10 characters are
    classified with ``predict_sentiment`` and tallied per label.

    Reddit API credentials are read from the ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET`` and ``REDDIT_USER_AGENT`` environment variables;
    the original placeholder strings are used when a variable is unset.

    Args:
        url: URL of the Reddit post.

    Returns:
        A ``(text, figure)`` tuple. ``text`` holds one ``"<label>: <count>"``
        line per sentiment and ``figure`` is a Plotly pie chart of the same
        counts. When the URL is blank, no qualifying comments exist, or any
        step raises, ``text`` is an explanatory message (the exception text
        in the last case) and ``figure`` is ``None``.
    """
    if not url or not url.strip():
        return "Please enter a Reddit post URL.", None

    try:
        # Built inside the try so that configuration problems are reported in
        # the UI just like request failures.
        reddit = praw.Reddit(
            client_id=os.getenv("REDDIT_CLIENT_ID", "YOUR_CLIENT_ID"),
            client_secret=os.getenv("REDDIT_CLIENT_SECRET", "YOUR_CLIENT_SECRET"),
            user_agent=os.getenv("REDDIT_USER_AGENT", "YOUR_USER_AGENT"),
        )
        submission = reddit.submission(url=url.strip())
        # limit=0 drops the unexpanded "more comments" placeholders instead
        # of fetching them.
        submission.comments.replace_more(limit=0)
        comments = [
            comment.body
            for comment in submission.comments.list()
            if len(comment.body) > 10
        ][:100]

        # Without this guard the result would be an empty string and an
        # empty pie chart.
        if not comments:
            return "No comments longer than 10 characters were found for this post.", None

        sentiment_counts = Counter(predict_sentiment(comment) for comment in comments)
        result_text = "\n".join(f"{s}: {c}" for s, c in sentiment_counts.items())

        # Pie chart
        fig = go.Figure(data=[go.Pie(labels=list(sentiment_counts.keys()),
                                     values=list(sentiment_counts.values()),
                                     hole=0.3)])
        fig.update_layout(title="Sentiment Distribution of Reddit Comments")
        return result_text, fig
    except Exception as e:
        # UI boundary: show the failure to the user rather than crashing the app.
        return str(e), None
412
+
413
+ # Subreddit Analysis Function
414
def analyze_subreddit(subreddit_name):
    """Summarize the sentiment of a subreddit's current "hot" posts.

    Up to 100 hot posts are fetched; each post's title and self-text are
    joined with a space and classified with ``predict_sentiment``.

    Reddit API credentials are read from the ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET`` and ``REDDIT_USER_AGENT`` environment variables;
    the original placeholder strings are used when a variable is unset.

    Args:
        subreddit_name: Subreddit name. Surrounding whitespace and a leading
            ``r/`` or ``/r/`` are tolerated and removed.

    Returns:
        A ``(text, figure)`` tuple. ``text`` holds one ``"<label>: <count>"``
        line per sentiment and ``figure`` is a Plotly pie chart of the same
        counts. When the name is blank, no post text is found, or any step
        raises, ``text`` is an explanatory message (the exception text in the
        last case) and ``figure`` is ``None``.
    """
    # The UI asks for the bare name, but users often paste "r/name".
    name = (subreddit_name or "").strip().lstrip("/")
    if name.lower().startswith("r/"):
        name = name[2:]
    name = name.strip("/").strip()
    if not name:
        return "Please enter a subreddit name.", None

    try:
        # Built inside the try so that configuration problems are reported in
        # the UI just like request failures.
        reddit = praw.Reddit(
            client_id=os.getenv("REDDIT_CLIENT_ID", "YOUR_CLIENT_ID"),
            client_secret=os.getenv("REDDIT_CLIENT_SECRET", "YOUR_CLIENT_SECRET"),
            user_agent=os.getenv("REDDIT_USER_AGENT", "YOUR_USER_AGENT"),
        )
        posts = reddit.subreddit(name).hot(limit=100)
        texts = [
            post.title + " " + post.selftext
            for post in posts
            if post.selftext or post.title
        ]

        if not texts:
            return "No valid text data found in subreddit.", None

        sentiment_counts = Counter(predict_sentiment(text) for text in texts)
        result_text = "\n".join(f"{s}: {c}" for s, c in sentiment_counts.items())

        # Pie chart
        fig = go.Figure(data=[go.Pie(labels=list(sentiment_counts.keys()),
                                     values=list(sentiment_counts.values()),
                                     hole=0.3)])
        fig.update_layout(title=f"Sentiment Distribution in r/{name}")

        return result_text, fig
    except Exception as e:
        # UI boundary: show the failure to the user rather than crashing the app.
        return str(e), None
 
441
 
442
+ # Image Upload Functionality
443
+ from PIL import Image
444
+ import pytesseract
445
 
446
def extract_text_from_image(image):
    """Run OCR on ``image`` and return the recognized text.

    Args:
        image: ``None``, a PIL image, a NumPy array, or anything
            ``PIL.Image.open`` accepts (a file path or a file object).

    Returns:
        The text reported by ``pytesseract.image_to_string``; an empty string
        when ``image`` is ``None``; or a message starting with
        ``"Error extracting text: "`` when loading or OCR raises.
    """
    # Nothing was supplied, so there is nothing to read.
    if image is None:
        return ""
    try:
        if isinstance(image, Image.Image):
            return pytesseract.image_to_string(image)
        if isinstance(image, np.ndarray):
            # Image.open only takes paths or file objects, so an in-memory
            # array has to be wrapped instead.
            return pytesseract.image_to_string(Image.fromarray(image))
        # Path or file object: close the underlying file once OCR is done.
        with Image.open(image) as img:
            return pytesseract.image_to_string(img)
    except Exception as e:
        return f"Error extracting text: {e}"
453
+
454
def analyze_image_sentiment(image):
    """Extract text from ``image`` and report its predicted sentiment.

    Args:
        image: The image value to hand to ``extract_text_from_image``.

    Returns:
        A string with the extracted text and its predicted sentiment;
        ``"No text extracted."`` when OCR yields nothing but whitespace; or
        the extractor's own error message when extraction failed.
    """
    extracted_text = extract_text_from_image(image)

    # The extractor reports failures as a string with this prefix. Surface
    # that message rather than classifying it as if it were OCR output.
    if extracted_text.startswith("Error extracting text:"):
        return extracted_text

    # Whitespace-only output is truthy but carries no text to classify.
    cleaned_text = extracted_text.strip()
    if not cleaned_text:
        return "No text extracted."

    sentiment = predict_sentiment(cleaned_text)
    return f"Extracted Text: {cleaned_text}\n\nPredicted Sentiment: {sentiment}"
460
+
461
# Gradio Interface
# One tab per analysis mode; each button is wired to the handler of the same
# purpose defined earlier in this file.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Sentiment Analysis App")

    # Classify a single piece of free text.
    with gr.Tab("Analyze Text"):
        input_text = gr.Textbox(label="Enter text")
        output_text = gr.Textbox(label="Predicted Sentiment")
        analyze_btn = gr.Button("Analyze")
        analyze_btn.click(fn=predict_sentiment, inputs=input_text, outputs=output_text)

    # Tally sentiment over the comments of one Reddit post.
    with gr.Tab("Analyze Reddit URL"):
        reddit_url = gr.Textbox(label="Enter Reddit post URL")
        url_result = gr.Textbox(label="Sentiment Counts")
        url_plot = gr.Plot(label="Pie Chart")
        analyze_url_btn = gr.Button("Analyze Reddit Comments")
        analyze_url_btn.click(fn=analyze_reddit_url, inputs=reddit_url, outputs=[url_result, url_plot])

    # OCR an uploaded image, then classify the extracted text.
    with gr.Tab("Analyze Image"):
        # type="filepath" hands the handler a path on disk, which is what its
        # Image.open call needs; without it the handler receives an array.
        image_input = gr.Image(type="filepath", label="Upload an image")
        image_result = gr.Textbox(label="Sentiment from Image Text")
        analyze_img_btn = gr.Button("Analyze Image")
        analyze_img_btn.click(fn=analyze_image_sentiment, inputs=image_input, outputs=image_result)

    # Tally sentiment over a subreddit's hot posts.
    with gr.Tab("Analyze Subreddit"):
        subreddit_input = gr.Textbox(label="Enter subreddit name (without r/)")
        subreddit_result = gr.Textbox(label="Sentiment Counts")
        subreddit_plot = gr.Plot(label="Pie Chart")
        analyze_subreddit_btn = gr.Button("Analyze Subreddit")
        analyze_subreddit_btn.click(fn=analyze_subreddit, inputs=subreddit_input, outputs=[subreddit_result, subreddit_plot])

demo.launch()
491
 
 
 
492
 
493
 
494