shrish191 committed on
Commit
804adf6
·
verified ·
1 Parent(s): 27e835e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -215,33 +215,38 @@ demo = gr.Interface(
215
  demo.launch()
216
  '''
217
  import gradio as gr
218
- from transformers import TFBertForSequenceClassification, BertTokenizer, pipeline
219
  import tensorflow as tf
220
  import praw
221
  import os
222
 
223
- # Load main BERT model and tokenizer
 
 
 
 
 
224
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
225
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
226
 
227
- # Load fallback sentiment pipeline model
228
- fallback_classifier = pipeline("text-classification", model="VinMir/GordonAI-sentiment_analysis")
229
-
230
- # Label mapping for main model
231
  LABELS = {
232
  0: "Neutral",
233
  1: "Positive",
234
  2: "Negative"
235
  }
236
 
237
- # Reddit API setup (secure credentials from Hugging Face secrets)
 
 
 
 
 
238
  reddit = praw.Reddit(
239
  client_id=os.getenv("REDDIT_CLIENT_ID"),
240
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
241
  user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
242
  )
243
 
244
- # Fetch content from Reddit URL
245
  def fetch_reddit_text(reddit_url):
246
  try:
247
  submission = reddit.submission(url=reddit_url)
@@ -249,7 +254,15 @@ def fetch_reddit_text(reddit_url):
249
  except Exception as e:
250
  return f"Error fetching Reddit post: {str(e)}"
251
 
252
- # Sentiment classification function
 
 
 
 
 
 
 
 
253
  def classify_sentiment(text_input, reddit_url):
254
  if reddit_url.strip():
255
  text = fetch_reddit_text(reddit_url)
@@ -262,7 +275,6 @@ def classify_sentiment(text_input, reddit_url):
262
  return f"[!] {text}"
263
 
264
  try:
265
- # Main BERT model prediction
266
  inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
267
  outputs = model(inputs)
268
  probs = tf.nn.softmax(outputs.logits, axis=1)
@@ -270,9 +282,7 @@ def classify_sentiment(text_input, reddit_url):
270
  pred_label = tf.argmax(probs, axis=1).numpy()[0]
271
 
272
  if confidence < 0.5:
273
- # Use fallback model silently
274
- fallback = fallback_classifier(text)[0]['label']
275
- return f"Prediction: {fallback}"
276
 
277
  return f"Prediction: {LABELS[pred_label]}"
278
  except Exception as e:
@@ -301,3 +311,6 @@ demo = gr.Interface(
301
  demo.launch()
302
 
303
 
 
 
 
 
215
  demo.launch()
216
  '''
217
  import gradio as gr
218
+ from transformers import TFBertForSequenceClassification, BertTokenizer
219
  import tensorflow as tf
220
  import praw
221
  import os
222
 
223
+ # Fallback imports
224
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
225
+ import torch
226
+ from scipy.special import softmax
227
+
228
+ # Load main model and tokenizer
229
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
230
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
231
 
 
 
 
 
232
  LABELS = {
233
  0: "Neutral",
234
  1: "Positive",
235
  2: "Negative"
236
  }
237
 
238
+ # Load fallback model and tokenizer
239
+ fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
240
+ fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
241
+ fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
242
+
243
+ # Reddit API
244
  reddit = praw.Reddit(
245
  client_id=os.getenv("REDDIT_CLIENT_ID"),
246
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
247
  user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
248
  )
249
 
 
250
  def fetch_reddit_text(reddit_url):
251
  try:
252
  submission = reddit.submission(url=reddit_url)
 
254
  except Exception as e:
255
  return f"Error fetching Reddit post: {str(e)}"
256
 
257
+ # Fallback classifier using RoBERTa
258
+ def fallback_classifier(text):
259
+ encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
260
+ with torch.no_grad():
261
+ output = fallback_model(**encoded_input)
262
+ scores = softmax(output.logits.numpy()[0])
263
+ labels = ['Negative', 'Neutral', 'Positive']
264
+ return f"Prediction: {labels[scores.argmax()]}"
265
+
266
  def classify_sentiment(text_input, reddit_url):
267
  if reddit_url.strip():
268
  text = fetch_reddit_text(reddit_url)
 
275
  return f"[!] {text}"
276
 
277
  try:
 
278
  inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
279
  outputs = model(inputs)
280
  probs = tf.nn.softmax(outputs.logits, axis=1)
 
282
  pred_label = tf.argmax(probs, axis=1).numpy()[0]
283
 
284
  if confidence < 0.5:
285
+ return fallback_classifier(text)
 
 
286
 
287
  return f"Prediction: {LABELS[pred_label]}"
288
  except Exception as e:
 
311
  demo.launch()
312
 
313
 
314
+
315
+
316
+