shrish191 commited on
Commit
27e835e
·
verified ·
1 Parent(s): 4071ce6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -1
app.py CHANGED
@@ -138,7 +138,7 @@ demo = gr.Interface(
138
 
139
  demo.launch()
140
  '''
141
- import gradio as gr
142
  from transformers import TFBertForSequenceClassification, BertTokenizer
143
  import tensorflow as tf
144
  import praw
@@ -213,5 +213,91 @@ demo = gr.Interface(
213
  )
214
 
215
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
 
 
138
 
139
  demo.launch()
140
  '''
141
+ '''import gradio as gr
142
  from transformers import TFBertForSequenceClassification, BertTokenizer
143
  import tensorflow as tf
144
  import praw
 
213
  )
214
 
215
  demo.launch()
216
+ '''
217
+ import gradio as gr
218
+ from transformers import TFBertForSequenceClassification, BertTokenizer, pipeline
219
+ import tensorflow as tf
220
+ import praw
221
+ import os
222
+
223
+ # Load main BERT model and tokenizer
224
+ model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
225
+ tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
226
+
227
+ # Load fallback sentiment pipeline model
228
+ fallback_classifier = pipeline("text-classification", model="VinMir/GordonAI-sentiment_analysis")
229
+
230
+ # Label mapping for main model
231
+ LABELS = {
232
+ 0: "Neutral",
233
+ 1: "Positive",
234
+ 2: "Negative"
235
+ }
236
+
237
+ # Reddit API setup (secure credentials from Hugging Face secrets)
238
+ reddit = praw.Reddit(
239
+ client_id=os.getenv("REDDIT_CLIENT_ID"),
240
+ client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
241
+ user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
242
+ )
243
+
244
+ # Fetch content from Reddit URL
245
+ def fetch_reddit_text(reddit_url):
246
+ try:
247
+ submission = reddit.submission(url=reddit_url)
248
+ return f"{submission.title}\n\n{submission.selftext}"
249
+ except Exception as e:
250
+ return f"Error fetching Reddit post: {str(e)}"
251
+
252
+ # Sentiment classification function
253
+ def classify_sentiment(text_input, reddit_url):
254
+ if reddit_url.strip():
255
+ text = fetch_reddit_text(reddit_url)
256
+ elif text_input.strip():
257
+ text = text_input
258
+ else:
259
+ return "[!] Please enter some text or a Reddit post URL."
260
+
261
+ if text.lower().startswith("error") or "Unable to extract" in text:
262
+ return f"[!] {text}"
263
+
264
+ try:
265
+ # Main BERT model prediction
266
+ inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
267
+ outputs = model(inputs)
268
+ probs = tf.nn.softmax(outputs.logits, axis=1)
269
+ confidence = float(tf.reduce_max(probs).numpy())
270
+ pred_label = tf.argmax(probs, axis=1).numpy()[0]
271
+
272
+ if confidence < 0.5:
273
+ # Use fallback model silently
274
+ fallback = fallback_classifier(text)[0]['label']
275
+ return f"Prediction: {fallback}"
276
+
277
+ return f"Prediction: {LABELS[pred_label]}"
278
+ except Exception as e:
279
+ return f"[!] Prediction error: {str(e)}"
280
+
281
+ # Gradio interface
282
+ demo = gr.Interface(
283
+ fn=classify_sentiment,
284
+ inputs=[
285
+ gr.Textbox(
286
+ label="Text Input (can be tweet or any content)",
287
+ placeholder="Paste tweet or type any content here...",
288
+ lines=4
289
+ ),
290
+ gr.Textbox(
291
+ label="Reddit Post URL",
292
+ placeholder="Paste a Reddit post URL (optional)",
293
+ lines=1
294
+ ),
295
+ ],
296
+ outputs="text",
297
+ title="Sentiment Analyzer",
298
+ description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
299
+ )
300
+
301
+ demo.launch()
302
 
303