Update app.py
app.py
CHANGED
@@ -237,7 +237,7 @@ demo = gr.Interface(
 
 demo.launch()
 '''
-
+'''
 import gradio as gr
 from transformers import TFBertForSequenceClassification, BertTokenizer
 import tensorflow as tf
@@ -358,6 +358,155 @@ demo = gr.Interface(
 )
 
 demo.launch()
+'''
+import gradio as gr
+from transformers import TFBertForSequenceClassification, BertTokenizer
+import tensorflow as tf
+import praw
+import os
+import pytesseract
+import cv2  # required by the OCR pre-processing in classify_sentiment
+from PIL import Image
+import numpy as np
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from scipy.special import softmax
+import plotly.graph_objs as go
+
|
375 |
+
# Install tesseract OCR (only runs once in Hugging Face Spaces)
|
376 |
+
os.system("apt-get update && apt-get install -y tesseract-ocr")
|
377 |
+
|
378 |
+
# Load main model
|
379 |
+
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
|
380 |
+
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
|
381 |
+
|
382 |
+
LABELS = {
|
383 |
+
0: "Neutral",
|
384 |
+
1: "Positive",
|
385 |
+
2: "Negative"
|
386 |
+
}
|
387 |
+
|
388 |
+
# Load fallback model
|
389 |
+
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
390 |
+
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
|
391 |
+
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
|
392 |
+
|
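+# Note: the credentials below are read from the environment; on Hugging Face
+# Spaces they would typically be configured as repository secrets.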
+# Reddit API setup
+reddit = praw.Reddit(
+    client_id=os.getenv("REDDIT_CLIENT_ID"),
+    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
+    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
+)
+
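+# The two fetchers below return error text instead of raising, so failures
+# surface directly in the Gradio output.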
+def fetch_reddit_text(reddit_url):
+    try:
+        submission = reddit.submission(url=reddit_url)
+        return f"{submission.title}\n\n{submission.selftext}"
+    except Exception as e:
+        return f"Error fetching Reddit post: {str(e)}"
+
+def fetch_subreddit_texts(subreddit_name, limit=15):
+    texts = []
+    try:
+        subreddit = reddit.subreddit(subreddit_name)
+        for submission in subreddit.hot(limit=limit):
+            combined = f"{submission.title} {submission.selftext}".strip()
+            if combined:
+                texts.append(combined)
+        return texts
+    except Exception as e:
+        return [f"Error fetching subreddit: {str(e)}"]
+
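+# The label order below (Negative, Neutral, Positive) is the one documented
+# for cardiffnlp/twitter-roberta-base-sentiment; it intentionally differs
+# from the LABELS mapping of the fine-tuned BERT model above.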
+def fallback_classifier(text):
+    encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
+    with torch.no_grad():
+        output = fallback_model(**encoded_input)
+    scores = softmax(output.logits.numpy()[0])
+    labels = ['Negative', 'Neutral', 'Positive']
+    return f"Prediction: {labels[scores.argmax()]}"
+
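+# Dashboard scoring: each post goes through the BERT model first, and any
+# prediction under 0.5 confidence is re-scored by the RoBERTa fallback.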
+def classify_multiple_sentiments(texts):
+    counts = {"Positive": 0, "Neutral": 0, "Negative": 0}
+    for text in texts:
+        try:
+            inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
+            outputs = model(inputs)
+            probs = tf.nn.softmax(outputs.logits, axis=1)
+            confidence = float(tf.reduce_max(probs).numpy())
+            pred_label = tf.argmax(probs, axis=1).numpy()[0]
+
+            if confidence < 0.5:
+                label = fallback_classifier(text).split(":")[-1].strip()
+            else:
+                label = LABELS[pred_label]
+
+            counts[label] += 1
+        except Exception:
+            continue
+    return counts
+
+def sentiment_pie_chart(counts):
+    labels = list(counts.keys())
+    values = list(counts.values())
+    fig = go.Figure(data=[go.Pie(labels=labels, values=values, hole=.3)])
+    fig.update_layout(title_text="Sentiment Distribution in Subreddit")
+    return fig
+
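+# Input routing: a subreddit name wins over a Reddit URL, which wins over an
+# uploaded image (OCR), which wins over plain text.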
+def classify_sentiment(text_input, reddit_url, image, subreddit_name):
+    # Subreddit dashboard has priority; every branch returns a
+    # (message, figure) pair so the interface can show either one
+    if subreddit_name.strip():
+        texts = fetch_subreddit_texts(subreddit_name)
+        if not texts:
+            return "[!] No posts found in this subreddit.", None
+        if "Error" in texts[0]:
+            return texts[0], None
+        counts = classify_multiple_sentiments(texts)
+        return None, sentiment_pie_chart(counts)
+
+    if reddit_url.strip():
+        text = fetch_reddit_text(reddit_url)
+    elif image is not None:
+        try:
+            img_array = np.array(image)
+            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+            thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
+            text = pytesseract.image_to_string(thresh)
+        except Exception as e:
+            return f"[!] OCR failed: {str(e)}", None
+    elif text_input.strip():
+        text = text_input
+    else:
+        return "[!] Please enter some text, upload an image, or provide a Reddit URL.", None
+
+    if text.lower().startswith("error") or "Unable to extract" in text:
+        return f"[!] {text}", None
+
+    try:
+        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
+        outputs = model(inputs)
+        probs = tf.nn.softmax(outputs.logits, axis=1)
+        confidence = float(tf.reduce_max(probs).numpy())
+        pred_label = tf.argmax(probs, axis=1).numpy()[0]
+
+        if confidence < 0.5:
+            return fallback_classifier(text), None
+
+        return f"Prediction: {LABELS[pred_label]}", None
+    except Exception as e:
+        return f"[!] Prediction error: {str(e)}", None
+
+# Gradio interface: text prediction and pie chart are separate outputs,
+# matching the (message, figure) contract of classify_sentiment
+demo = gr.Interface(
+    fn=classify_sentiment,
+    inputs=[
+        gr.Textbox(label="Text Input (can be tweet or any content)", placeholder="Paste tweet or type any content here...", lines=4),
+        gr.Textbox(label="Reddit Post URL", placeholder="Paste a Reddit post URL (optional)", lines=1),
+        gr.Image(label="Upload Image (optional)", type="pil"),
+        gr.Textbox(label="Subreddit Name", placeholder="e.g. AskReddit (optional)", lines=1),
+    ],
+    outputs=[gr.Textbox(label="Result"), gr.Plot(label="Sentiment Distribution")],
+    title="Sentiment Analyzer with Dashboard",
+    description="🔍 Paste any text, Reddit post URL, upload an image, or enter a subreddit name to analyze sentiment."
+)
+
+demo.launch()
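For reference, a minimal usage sketch of the resulting classify_sentiment contract. The sample inputs are illustrative assumptions, not part of the commit, and demo.launch() at the bottom of app.py blocks on import, so it would need to be guarded or commented out first:

# Hypothetical smoke test, assuming app.py's definitions are loaded
msg, fig = classify_sentiment("I love this!", "", None, "")
print(msg)    # e.g. "Prediction: Positive"
msg, fig = classify_sentiment("", "", None, "AskReddit")
if fig is not None:
    fig.show()    # Plotly pie chart of sentiment counts per post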