winamnd committed on
Commit
e08cb5e
·
verified ·
1 Parent(s): 4ee3a20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -45
app.py CHANGED
@@ -27,10 +27,10 @@ else:
27
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
28
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
29
 
30
- # Set the model to evaluation mode to disable dropout layers
31
  model.eval()
32
 
33
- # Load OCR Methods
34
  def ocr_with_paddle(img):
35
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
36
  result = ocr.ocr(img)
@@ -48,8 +48,8 @@ def ocr_with_easy(img):
48
  results = reader.readtext(gray_image, detail=0)
49
  return ' '.join(results)
50
 
51
- # OCR Function
52
- def generate_ocr(method, img):
53
  if img is None:
54
  raise gr.Error("Please upload an image!")
55
 
@@ -64,63 +64,58 @@ def generate_ocr(method, img):
64
  else: # KerasOCR
65
  text_output = ocr_with_keras(img)
66
 
67
- # Clean and truncate the extracted text
68
  text_output = text_output.strip()
 
69
  if len(text_output) == 0:
70
- return "No text detected!", "Cannot classify"
71
 
72
- # Classify Text as Spam or Not Spam
 
 
 
 
 
 
 
73
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
74
 
 
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
- probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
78
  prediction = torch.argmax(probs, dim=1).item()
79
 
80
  label_map = {0: "Not Spam", 1: "Spam"}
81
  label = label_map[prediction]
82
 
83
- # Save results using the external save function
84
  save_results_to_repo(text_output, label)
85
 
86
  return text_output, label
87
 
88
- # Save results to JSON file
89
- RESULTS_JSON = "ocr_results.json"
90
-
91
- def save_to_json(text, label):
92
- data = {"Extracted Text": text, "Spam Classification": label}
93
-
94
- # Save to JSON file
95
- with open(RESULTS_JSON, "w") as json_file:
96
- json.dump(data, json_file, indent=4)
97
-
98
- return f"Results saved to {RESULTS_JSON}"
99
-
100
- # Create Gradio Interface
101
  image_input = gr.Image()
102
- method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
103
- output_text = gr.Textbox(label="Extracted Text")
104
- output_label = gr.Textbox(label="Spam Classification")
105
- save_button = gr.Button("Save to JSON")
106
- save_output = gr.Textbox(label="Save Status")
107
-
108
- # Main OCR Interface
109
- demo = gr.Interface(
110
- fn=generate_ocr,
111
- inputs=[method_input, image_input],
112
- outputs=[output_text, output_label],
113
- title="OCR Spam Classifier",
114
- description="Upload an image, extract text, and classify it as Spam or Not Spam.",
115
- theme="compact",
116
- )
117
-
118
- # *Attach Save Button to Function*
119
- save_button.click(
120
- fn=save_to_json,
121
- inputs=[output_text, output_label],
122
- outputs=[save_output]
123
- )
124
 
125
  # Launch App
126
- demo.launch()
 
 
27
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
28
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
29
 
30
+ # Set model to evaluation mode
31
  model.eval()
32
 
33
+ # OCR Methods
34
  def ocr_with_paddle(img):
35
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
36
  result = ocr.ocr(img)
 
48
  results = reader.readtext(gray_image, detail=0)
49
  return ' '.join(results)
50
 
51
+ # OCR Extraction Function
52
+ def extract_text(method, img):
53
  if img is None:
54
  raise gr.Error("Please upload an image!")
55
 
 
64
  else: # KerasOCR
65
  text_output = ocr_with_keras(img)
66
 
67
+ # Clean extracted text
68
  text_output = text_output.strip()
69
+
70
  if len(text_output) == 0:
71
+ return "No text detected!"
72
 
73
+ return text_output
74
+
75
+ # Classification Function
76
+ def classify_text(text_output):
77
+ if text_output.strip() == "No text detected!":
78
+ return text_output, "Cannot classify"
79
+
80
+ # Tokenize text
81
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
82
 
83
+ # Model inference
84
  with torch.no_grad():
85
  outputs = model(**inputs)
86
+ probs = F.softmax(outputs.logits, dim=1)
87
  prediction = torch.argmax(probs, dim=1).item()
88
 
89
  label_map = {0: "Not Spam", 1: "Spam"}
90
  label = label_map[prediction]
91
 
92
+ # Save results automatically
93
  save_results_to_repo(text_output, label)
94
 
95
  return text_output, label
96
 
97
+ # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
98
  image_input = gr.Image()
99
+ method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="Choose OCR Method")
100
+ output_text = gr.Textbox(label="Extracted Text", interactive=True)
101
+ output_label = gr.Textbox(label="Spam Classification", interactive=False)
102
+
103
+ # Define UI layout
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("## OCR Spam Classifier")
106
+
107
+ with gr.Row():
108
+ method_input.render()
109
+
110
+ with gr.Row():
111
+ image_input.render()
112
+
113
+ extract_button = gr.Button("Submit")
114
+ classify_button = gr.Button("Classify")
115
+
116
+ extract_button.click(fn=extract_text, inputs=[method_input, image_input], outputs=[output_text])
117
+ classify_button.click(fn=classify_text, inputs=[output_text], outputs=[output_text, output_label])
 
 
 
118
 
119
  # Launch App
120
+ if __name__ == "__main__":
121
+ demo.launch()