winamnd committed on
Commit
a92b56c
·
verified ·
1 Parent(s): 104c39e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -10,6 +10,7 @@ import easyocr
10
  import keras_ocr
11
  from paddleocr import PaddleOCR
12
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
 
13
 
14
  # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
@@ -19,7 +20,7 @@ RESULTS_CSV = "ocr_results.csv"
19
  # Ensure model exists
20
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
21
  print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
22
- model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
23
  model.save_pretrained(MODEL_PATH)
24
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
25
  tokenizer.save_pretrained(MODEL_PATH)
@@ -63,11 +64,15 @@ def generate_ocr(method, img):
63
  text_output = ocr_with_keras(img)
64
 
65
  # Classify Text as Spam or Not Spam
66
- inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
 
67
  with torch.no_grad():
68
  outputs = model(**inputs)
69
- prediction = torch.argmax(outputs.logits, dim=1).item()
70
- label = "Spam" if prediction == 1 else "Not Spam"
 
 
 
71
 
72
  # Save results
73
  save_results(text_output, label)
 
10
  import keras_ocr
11
  from paddleocr import PaddleOCR
12
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
13
+ import torch.nn.functional as F # Added for softmax
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
 
20
  # Ensure model exists
21
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
22
  print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
23
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
24
  model.save_pretrained(MODEL_PATH)
25
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
26
  tokenizer.save_pretrained(MODEL_PATH)
 
64
  text_output = ocr_with_keras(img)
65
 
66
  # Classify Text as Spam or Not Spam
67
+ inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
68
+
69
  with torch.no_grad():
70
  outputs = model(**inputs)
71
+ probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
72
+ prediction = torch.argmax(probs, dim=1).item()
73
+
74
+ label_map = {0: "Not Spam", 1: "Spam"}
75
+ label = label_map[prediction]
76
 
77
  # Save results
78
  save_results(text_output, label)