Update app.py
app.py CHANGED
@@ -237,6 +237,7 @@ demo = gr.Interface(
 
 demo.launch()
 '''
+
 import gradio as gr
 from transformers import TFBertForSequenceClassification, BertTokenizer
 import tensorflow as tf
@@ -246,6 +247,7 @@ import pytesseract
 from PIL import Image
 import cv2
 import numpy as np
+import re
 
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
@@ -291,14 +293,22 @@ def fallback_classifier(text):
     labels = ['Negative', 'Neutral', 'Positive']
     return f"Prediction: {labels[scores.argmax()]}"
 
+def clean_ocr_text(text):
+    text = text.strip()
+    text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces and newlines
+    text = re.sub(r'[^\x00-\x7F]+', '', text)  # Remove non-ASCII characters
+    return text
+
 def classify_sentiment(text_input, reddit_url, image):
-    # Priority: Reddit > Image > Textbox
     if reddit_url.strip():
         text = fetch_reddit_text(reddit_url)
     elif image is not None:
         try:
             img_array = np.array(image)
-            text = pytesseract.image_to_string(img_array)
+            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+            _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
+            text = pytesseract.image_to_string(thresh)
+            text = clean_ocr_text(text)
         except Exception as e:
             return f"[!] OCR failed: {str(e)}"
     elif text_input.strip():
@@ -353,3 +363,4 @@ demo.launch()
 
 
 
+
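
For reference, the new OCR path can be exercised outside the Gradio app. The sketch below is a minimal, standalone reproduction of the added preprocessing and cleanup steps, assuming a local Tesseract install; the script name and the "sample.png" path are placeholders, not part of the commit.

# Minimal sketch of the added OCR pipeline, run as a standalone script.
# Assumes Tesseract is installed and a hypothetical image file "sample.png" exists.
import re

import cv2
import numpy as np
import pytesseract
from PIL import Image


def clean_ocr_text(text):
    # Same cleanup as the committed helper: trim, collapse whitespace, drop non-ASCII.
    text = text.strip()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\x00-\x7F]+', '', text)
    return text


def ocr_image(path):
    # Mirror the image branch of classify_sentiment: RGB array -> grayscale ->
    # fixed-threshold binarization -> Tesseract -> text cleanup.
    img_array = np.array(Image.open(path).convert("RGB"))
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
    return clean_ocr_text(pytesseract.image_to_string(thresh))


if __name__ == "__main__":
    print(ocr_image("sample.png"))  # placeholder path

The fixed threshold of 150 matches the committed code; for noisier screenshots, cv2.adaptiveThreshold could be a reasonable alternative, though that is a departure from what this commit does.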