Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,8 @@ import easyocr
|
|
9 |
import keras_ocr
|
10 |
from paddleocr import PaddleOCR
|
11 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
12 |
-
import torch.nn.functional as F
|
13 |
-
from save_results import save_results_to_repo
|
14 |
|
15 |
# Paths
|
16 |
MODEL_PATH = "./distilbert_spam_model"
|
@@ -29,6 +29,9 @@ else:
|
|
29 |
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
|
30 |
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
|
31 |
|
|
|
|
|
|
|
32 |
# Load OCR Methods
|
33 |
def ocr_with_paddle(img):
|
34 |
ocr = PaddleOCR(lang='en', use_angle_cls=True)
|
@@ -63,6 +66,11 @@ def generate_ocr(method, img):
|
|
63 |
else: # KerasOCR
|
64 |
text_output = ocr_with_keras(img)
|
65 |
|
|
|
|
|
|
|
|
|
|
|
66 |
# Classify Text as Spam or Not Spam
|
67 |
inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
68 |
|
@@ -71,7 +79,7 @@ def generate_ocr(method, img):
|
|
71 |
probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
|
72 |
prediction = torch.argmax(probs, dim=1).item()
|
73 |
|
74 |
-
label_map = {0: "Spam", 1: "
|
75 |
label = label_map[prediction]
|
76 |
|
77 |
# Save results using the external save function
|
@@ -96,4 +104,4 @@ demo = gr.Interface(
|
|
96 |
|
97 |
# Launch App
|
98 |
if __name__ == "__main__":
|
99 |
-
demo.launch()
|
|
|
9 |
import keras_ocr
|
10 |
from paddleocr import PaddleOCR
|
11 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from save_results import save_results_to_repo
|
14 |
|
15 |
# Paths
|
16 |
MODEL_PATH = "./distilbert_spam_model"
|
|
|
29 |
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
|
30 |
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
|
31 |
|
32 |
+
# Set the model to evaluation mode to disable dropout layers
|
33 |
+
model.eval()
|
34 |
+
|
35 |
# Load OCR Methods
|
36 |
def ocr_with_paddle(img):
|
37 |
ocr = PaddleOCR(lang='en', use_angle_cls=True)
|
|
|
66 |
else: # KerasOCR
|
67 |
text_output = ocr_with_keras(img)
|
68 |
|
69 |
+
# Clean and truncate the extracted text
|
70 |
+
text_output = text_output.strip()
|
71 |
+
if len(text_output) == 0:
|
72 |
+
return "No text detected!", "Cannot classify"
|
73 |
+
|
74 |
# Classify Text as Spam or Not Spam
|
75 |
inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
76 |
|
|
|
79 |
probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
|
80 |
prediction = torch.argmax(probs, dim=1).item()
|
81 |
|
82 |
+
label_map = {0: "Not Spam", 1: "Spam"}
|
83 |
label = label_map[prediction]
|
84 |
|
85 |
# Save results using the external save function
|
|
|
104 |
|
105 |
# Launch App
|
106 |
if __name__ == "__main__":
|
107 |
+
demo.launch()
|