winamnd committed on
Commit
deb409e
·
verified ·
1 Parent(s): 1269497

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -9,8 +9,8 @@ import easyocr
9
  import keras_ocr
10
  from paddleocr import PaddleOCR
11
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
- import torch.nn.functional as F # Added for softmax
13
- from save_results import save_results_to_repo # Import the new save function
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
@@ -29,6 +29,9 @@ else:
29
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
30
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
31
 
 
 
 
32
  # Load OCR Methods
33
  def ocr_with_paddle(img):
34
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
@@ -63,6 +66,11 @@ def generate_ocr(method, img):
63
  else: # KerasOCR
64
  text_output = ocr_with_keras(img)
65
 
 
 
 
 
 
66
  # Classify Text as Spam or Not Spam
67
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
68
 
@@ -71,7 +79,7 @@ def generate_ocr(method, img):
71
  probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
72
  prediction = torch.argmax(probs, dim=1).item()
73
 
74
- label_map = {0: "Spam", 1: "Not Spam"}
75
  label = label_map[prediction]
76
 
77
  # Save results using the external save function
@@ -96,4 +104,4 @@ demo = gr.Interface(
96
 
97
  # Launch App
98
  if __name__ == "__main__":
99
- demo.launch()
 
9
  import keras_ocr
10
  from paddleocr import PaddleOCR
11
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
+ import torch.nn.functional as F
13
+ from save_results import save_results_to_repo
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
 
29
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
30
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
31
 
32
+ # Set the model to evaluation mode to disable dropout layers
33
+ model.eval()
34
+
35
  # Load OCR Methods
36
  def ocr_with_paddle(img):
37
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
 
66
  else: # KerasOCR
67
  text_output = ocr_with_keras(img)
68
 
69
+ # Clean and truncate the extracted text
70
+ text_output = text_output.strip()
71
+ if len(text_output) == 0:
72
+ return "No text detected!", "Cannot classify"
73
+
74
  # Classify Text as Spam or Not Spam
75
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
 
 
79
  probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
80
  prediction = torch.argmax(probs, dim=1).item()
81
 
82
+ label_map = {0: "Not Spam", 1: "Spam"}
83
  label = label_map[prediction]
84
 
85
  # Save results using the external save function
 
104
 
105
  # Launch App
106
  if __name__ == "__main__":
107
+ demo.launch()