shrish191 commited on
Commit
bd6cebd
·
verified ·
1 Parent(s): 4f19b27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -237,6 +237,7 @@ demo = gr.Interface(
237
 
238
  demo.launch()
239
  '''
 
240
  import gradio as gr
241
  from transformers import TFBertForSequenceClassification, BertTokenizer
242
  import tensorflow as tf
@@ -246,6 +247,7 @@ import pytesseract
246
  from PIL import Image
247
  import cv2
248
  import numpy as np
 
249
 
250
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
251
  import torch
@@ -291,14 +293,22 @@ def fallback_classifier(text):
291
  labels = ['Negative', 'Neutral', 'Positive']
292
  return f"Prediction: {labels[scores.argmax()]}"
293
 
 
 
 
 
 
 
294
  def classify_sentiment(text_input, reddit_url, image):
295
- # Priority: Reddit > Image > Textbox
296
  if reddit_url.strip():
297
  text = fetch_reddit_text(reddit_url)
298
  elif image is not None:
299
  try:
300
  img_array = np.array(image)
301
- text = pytesseract.image_to_string(img_array)
 
 
 
302
  except Exception as e:
303
  return f"[!] OCR failed: {str(e)}"
304
  elif text_input.strip():
@@ -353,3 +363,4 @@ demo.launch()
353
 
354
 
355
 
 
 
237
 
238
  demo.launch()
239
  '''
240
+
241
  import gradio as gr
242
  from transformers import TFBertForSequenceClassification, BertTokenizer
243
  import tensorflow as tf
 
247
  from PIL import Image
248
  import cv2
249
  import numpy as np
250
+ import re
251
 
252
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
253
  import torch
 
293
  labels = ['Negative', 'Neutral', 'Positive']
294
  return f"Prediction: {labels[scores.argmax()]}"
295
 
296
+ def clean_ocr_text(text):
297
+ text = text.strip()
298
+ text = re.sub(r'\s+', ' ', text) # Replace multiple spaces and newlines
299
+ text = re.sub(r'[^\x00-\x7F]+', '', text) # Remove non-ASCII characters
300
+ return text
301
+
302
  def classify_sentiment(text_input, reddit_url, image):
 
303
  if reddit_url.strip():
304
  text = fetch_reddit_text(reddit_url)
305
  elif image is not None:
306
  try:
307
  img_array = np.array(image)
308
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
309
+ _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
310
+ text = pytesseract.image_to_string(thresh)
311
+ text = clean_ocr_text(text)
312
  except Exception as e:
313
  return f"[!] OCR failed: {str(e)}"
314
  elif text_input.strip():
 
363
 
364
 
365
 
366
+