winamnd committed on
Commit 17f2d95 · verified · 1 Parent(s): 4639dba

Update app.py

Files changed (1): app.py (+27 -14)
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import torch
 import json
-import csv
 import os
 import cv2
 import numpy as np
@@ -10,10 +9,10 @@ import keras_ocr
 from paddleocr import PaddleOCR
 from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
 import torch.nn.functional as F
-from save_results import save_results_to_repo  # Import the save function
 
 # Paths
 MODEL_PATH = "./distilbert_spam_model"
+RESULTS_JSON = "results.json"
 
 # Ensure model exists
 if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
@@ -27,10 +26,10 @@ else:
     model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
     tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
 
-# 🔹 Ensure model is in evaluation mode
+# Ensure model is in evaluation mode
 model.eval()
 
-# OCR Functions (No changes here)
+# OCR Functions
 def ocr_with_paddle(img):
     ocr = PaddleOCR(lang='en', use_angle_cls=True)
     result = ocr.ocr(img)
@@ -48,6 +47,22 @@ def ocr_with_easy(img):
     results = reader.readtext(gray_image, detail=0)
     return ' '.join(results)
 
+# Save results to JSON
+def save_to_json(text, label):
+    data = {"text": text, "classification": label}
+    if os.path.exists(RESULTS_JSON):
+        with open(RESULTS_JSON, "r") as file:
+            try:
+                results = json.load(file)
+            except json.JSONDecodeError:
+                results = []
+    else:
+        results = []
+
+    results.append(data)
+    with open(RESULTS_JSON, "w") as file:
+        json.dump(results, file, indent=4)
+
 # OCR & Classification Function
 def generate_ocr(method, img):
     if img is None:
@@ -64,25 +79,23 @@ def generate_ocr(method, img):
     else:  # KerasOCR
         text_output = ocr_with_keras(img)
 
-    # 🔹 Preprocess text properly
     text_output = text_output.strip()
     if len(text_output) == 0:
         return "No text detected!", "Cannot classify"
 
-    # 🔹 Tokenize text
+    # Tokenize text
     inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
-    # 🔹 Perform inference
+    # Perform inference
     with torch.no_grad():
         outputs = model(**inputs)
-    probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
-    spam_prob = probs[0][1].item()  # Probability of Spam
+    probs = F.softmax(outputs.logits, dim=1)
+    spam_prob = probs[0][1].item()
 
-    # 🔹 Adjust classification based on threshold (better than argmax)
     label = "Spam" if spam_prob > 0.5 else "Not Spam"
 
-    # 🔹 Save results using external function
-    save_results_to_repo(text_output, label)
+    # Save results to JSON
+    save_to_json(text_output, label)
 
     return text_output, label
@@ -102,5 +115,5 @@ demo = gr.Interface(
 )
 
 # Launch App
-if __name__ == "__main__":
-    demo.launch()
+if __name__ == "__main__":
+    demo.launch()
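
For reference, a minimal standalone sketch of the persistence path this commit introduces. It copies the save_to_json helper from the diff above and derives the label the same way generate_ocr does, but from a hand-written logits tensor instead of a real DistilBERT forward pass, so the checkpoint under ./distilbert_spam_model is not needed here; the sample text and logits values are placeholders. As in the app, results.json is created in the current working directory.

import json
import os
import torch
import torch.nn.functional as F

RESULTS_JSON = "results.json"

# Same helper as in the updated app.py: append one record to results.json.
def save_to_json(text, label):
    data = {"text": text, "classification": label}
    if os.path.exists(RESULTS_JSON):
        with open(RESULTS_JSON, "r") as file:
            try:
                results = json.load(file)
            except json.JSONDecodeError:
                results = []
    else:
        results = []

    results.append(data)
    with open(RESULTS_JSON, "w") as file:
        json.dump(results, file, indent=4)

# Placeholder logits standing in for model(**inputs).logits; shape (1, 2) = [not spam, spam].
logits = torch.tensor([[0.3, 1.9]])
probs = F.softmax(logits, dim=1)
spam_prob = probs[0][1].item()          # ~0.83 for these placeholder values
label = "Spam" if spam_prob > 0.5 else "Not Spam"

save_to_json("WIN A FREE PRIZE NOW", label)

with open(RESULTS_JSON) as file:
    print(file.read())                  # list of {"text": ..., "classification": ...} records

Because save_to_json reads the whole file and rewrites it on every call, each classification run appends one record to the same JSON list; repeated calls simply grow the list.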