shrish191 committed on
Commit
5a741e8
·
verified ·
1 Parent(s): 65d8742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -18
app.py CHANGED
@@ -138,31 +138,42 @@ demo = gr.Interface(
138
 
139
  demo.launch()
140
  '''
141
- '''import gradio as gr
 
 
 
142
  from transformers import TFBertForSequenceClassification, BertTokenizer
143
  import tensorflow as tf
144
  import praw
145
  import os
146
 
147
- # Load model and tokenizer from Hugging Face
 
 
 
 
 
148
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
149
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
150
 
151
- # Label mapping
152
  LABELS = {
153
  0: "Neutral",
154
  1: "Positive",
155
  2: "Negative"
156
  }
157
 
158
- # Reddit API setup (credentials loaded securely from secrets)
 
 
 
 
 
159
  reddit = praw.Reddit(
160
  client_id=os.getenv("REDDIT_CLIENT_ID"),
161
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
162
- user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
163
  )
164
 
165
- # Reddit post fetcher
166
  def fetch_reddit_text(reddit_url):
167
  try:
168
  submission = reddit.submission(url=reddit_url)
@@ -170,7 +181,15 @@ def fetch_reddit_text(reddit_url):
170
  except Exception as e:
171
  return f"Error fetching Reddit post: {str(e)}"
172
 
173
- # Main sentiment function
 
 
 
 
 
 
 
 
174
  def classify_sentiment(text_input, reddit_url):
175
  if reddit_url.strip():
176
  text = fetch_reddit_text(reddit_url)
@@ -186,13 +205,17 @@ def classify_sentiment(text_input, reddit_url):
186
  inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
187
  outputs = model(inputs)
188
  probs = tf.nn.softmax(outputs.logits, axis=1)
189
- pred_label = tf.argmax(probs, axis=1).numpy()[0]
190
  confidence = float(tf.reduce_max(probs).numpy())
191
- return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence:.2f})"
 
 
 
 
 
192
  except Exception as e:
193
  return f"[!] Prediction error: {str(e)}"
194
 
195
- # Gradio UI
196
  demo = gr.Interface(
197
  fn=classify_sentiment,
198
  inputs=[
@@ -211,7 +234,7 @@ demo = gr.Interface(
211
  title="Sentiment Analyzer",
212
  description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
213
  )
214
-
215
  demo.launch()
216
  '''
217
  import gradio as gr
@@ -219,13 +242,19 @@ from transformers import TFBertForSequenceClassification, BertTokenizer
219
  import tensorflow as tf
220
  import praw
221
  import os
222
-
 
 
 
223
 
224
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
225
  import torch
226
  from scipy.special import softmax
227
 
 
 
228
 
 
229
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
230
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
231
 
@@ -235,12 +264,12 @@ LABELS = {
235
  2: "Negative"
236
  }
237
 
238
-
239
  fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
240
  fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
241
  fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
242
 
243
- # Reddit API
244
  reddit = praw.Reddit(
245
  client_id=os.getenv("REDDIT_CLIENT_ID"),
246
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
@@ -254,7 +283,6 @@ def fetch_reddit_text(reddit_url):
254
  except Exception as e:
255
  return f"Error fetching Reddit post: {str(e)}"
256
 
257
-
258
  def fallback_classifier(text):
259
  encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
260
  with torch.no_grad():
@@ -263,13 +291,20 @@ def fallback_classifier(text):
263
  labels = ['Negative', 'Neutral', 'Positive']
264
  return f"Prediction: {labels[scores.argmax()]}"
265
 
266
- def classify_sentiment(text_input, reddit_url):
 
267
  if reddit_url.strip():
268
  text = fetch_reddit_text(reddit_url)
 
 
 
 
 
 
269
  elif text_input.strip():
270
  text = text_input
271
  else:
272
- return "[!] Please enter some text or a Reddit post URL."
273
 
274
  if text.lower().startswith("error") or "Unable to extract" in text:
275
  return f"[!] {text}"
@@ -302,10 +337,14 @@ demo = gr.Interface(
302
  placeholder="Paste a Reddit post URL (optional)",
303
  lines=1
304
  ),
 
 
 
 
305
  ],
306
  outputs="text",
307
  title="Sentiment Analyzer",
308
- description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
309
  )
310
 
311
  demo.launch()
 
138
 
139
  demo.launch()
140
  '''
141
+
142
+
143
+ '''
144
+ import gradio as gr
145
  from transformers import TFBertForSequenceClassification, BertTokenizer
146
  import tensorflow as tf
147
  import praw
148
  import os
149
 
150
+
151
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
152
+ import torch
153
+ from scipy.special import softmax
154
+
155
+
156
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
157
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
158
 
 
159
  LABELS = {
160
  0: "Neutral",
161
  1: "Positive",
162
  2: "Negative"
163
  }
164
 
165
+
166
+ fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
167
+ fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
168
+ fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
169
+
170
+ # Reddit API
171
  reddit = praw.Reddit(
172
  client_id=os.getenv("REDDIT_CLIENT_ID"),
173
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
174
+ user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui")
175
  )
176
 
 
177
  def fetch_reddit_text(reddit_url):
178
  try:
179
  submission = reddit.submission(url=reddit_url)
 
181
  except Exception as e:
182
  return f"Error fetching Reddit post: {str(e)}"
183
 
184
+
185
+ def fallback_classifier(text):
186
+ encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
187
+ with torch.no_grad():
188
+ output = fallback_model(**encoded_input)
189
+ scores = softmax(output.logits.numpy()[0])
190
+ labels = ['Negative', 'Neutral', 'Positive']
191
+ return f"Prediction: {labels[scores.argmax()]}"
192
+
193
  def classify_sentiment(text_input, reddit_url):
194
  if reddit_url.strip():
195
  text = fetch_reddit_text(reddit_url)
 
205
  inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
206
  outputs = model(inputs)
207
  probs = tf.nn.softmax(outputs.logits, axis=1)
 
208
  confidence = float(tf.reduce_max(probs).numpy())
209
+ pred_label = tf.argmax(probs, axis=1).numpy()[0]
210
+
211
+ if confidence < 0.5:
212
+ return fallback_classifier(text)
213
+
214
+ return f"Prediction: {LABELS[pred_label]}"
215
  except Exception as e:
216
  return f"[!] Prediction error: {str(e)}"
217
 
218
+ # Gradio interface
219
  demo = gr.Interface(
220
  fn=classify_sentiment,
221
  inputs=[
 
234
  title="Sentiment Analyzer",
235
  description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
236
  )
237
+
238
  demo.launch()
239
  '''
240
  import gradio as gr
 
242
  import tensorflow as tf
243
  import praw
244
  import os
245
+ import pytesseract
246
+ from PIL import Image
247
+ import cv2
248
+ import numpy as np
249
 
250
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
251
  import torch
252
  from scipy.special import softmax
253
 
254
+ # Install tesseract OCR (only runs once in Hugging Face Spaces)
255
+ os.system("apt-get update && apt-get install -y tesseract-ocr")
256
 
257
+ # Load main model
258
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
259
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
260
 
 
264
  2: "Negative"
265
  }
266
 
267
+ # Load fallback model
268
  fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
269
  fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
270
  fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
271
 
272
+ # Reddit API setup
273
  reddit = praw.Reddit(
274
  client_id=os.getenv("REDDIT_CLIENT_ID"),
275
  client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
 
283
  except Exception as e:
284
  return f"Error fetching Reddit post: {str(e)}"
285
 
 
286
  def fallback_classifier(text):
287
  encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
288
  with torch.no_grad():
 
291
  labels = ['Negative', 'Neutral', 'Positive']
292
  return f"Prediction: {labels[scores.argmax()]}"
293
 
294
+ def classify_sentiment(text_input, reddit_url, image):
295
+ # Priority: Reddit > Image > Textbox
296
  if reddit_url.strip():
297
  text = fetch_reddit_text(reddit_url)
298
+ elif image is not None:
299
+ try:
300
+ img_array = np.array(image)
301
+ text = pytesseract.image_to_string(img_array)
302
+ except Exception as e:
303
+ return f"[!] OCR failed: {str(e)}"
304
  elif text_input.strip():
305
  text = text_input
306
  else:
307
+ return "[!] Please enter some text, upload an image, or provide a Reddit URL."
308
 
309
  if text.lower().startswith("error") or "Unable to extract" in text:
310
  return f"[!] {text}"
 
337
  placeholder="Paste a Reddit post URL (optional)",
338
  lines=1
339
  ),
340
+ gr.Image(
341
+ label="Upload Image (optional)",
342
+ type="pil"
343
+ )
344
  ],
345
  outputs="text",
346
  title="Sentiment Analyzer",
347
+ description="🔍 Paste any text, Reddit post URL, or upload an image containing text to analyze sentiment.\n\n💡 Tweet URLs are not supported. Please paste tweet content or screenshot instead."
348
  )
349
 
350
  demo.launch()