winamnd commited on
Commit
0ab07a8
·
verified ·
1 Parent(s): fffd891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  import json
4
- import csv
5
  import os
6
  import cv2
7
  import numpy as np
@@ -15,7 +14,6 @@ from save_results import save_results_to_repo
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
17
  RESULTS_JSON = "ocr_results.json"
18
- RESULTS_CSV = "ocr_results.csv"
19
 
20
  # Ensure model exists
21
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
@@ -87,14 +85,23 @@ def generate_ocr(method, img):
87
 
88
  return text_output, label
89
 
 
 
 
 
 
 
 
90
  # Gradio Interface
91
  image_input = gr.Image()
92
  method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
93
  output_text = gr.Textbox(label="Extracted Text")
94
  output_label = gr.Textbox(label="Spam Classification")
 
 
95
 
96
  demo = gr.Interface(
97
- generate_ocr,
98
  inputs=[method_input, image_input],
99
  outputs=[output_text, output_label],
100
  title="OCR Spam Classifier",
@@ -102,6 +109,8 @@ demo = gr.Interface(
102
  theme="compact",
103
  )
104
 
105
- # Launch App
106
- if __name__ == "__main__":
107
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  import os
5
  import cv2
6
  import numpy as np
 
14
  # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
16
  RESULTS_JSON = "ocr_results.json"
 
17
 
18
  # Ensure model exists
19
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
 
85
 
86
  return text_output, label
87
 
88
+ # Save results to JSON file
89
+ def save_to_json(text, label):
90
+ data = {"extracted_text": text, "classification": label}
91
+ with open(RESULTS_JSON, "w") as f:
92
+ json.dump(data, f, indent=4)
93
+ return "Results saved to JSON file!"
94
+
95
  # Gradio Interface
96
  image_input = gr.Image()
97
  method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
98
  output_text = gr.Textbox(label="Extracted Text")
99
  output_label = gr.Textbox(label="Spam Classification")
100
+ save_button = gr.Button("Save to JSON")
101
+ save_output = gr.Textbox(label="Save Status")
102
 
103
  demo = gr.Interface(
104
+ fn=generate_ocr,
105
  inputs=[method_input, image_input],
106
  outputs=[output_text, output_label],
107
  title="OCR Spam Classifier",
 
109
  theme="compact",
110
  )
111
 
112
+ # Add Save Button Interaction
113
+ demo.add_component(save_button)
114
+ save_button.click(save_to_json, inputs=[output_text, output_label], outputs=[save_output])
115
+
116
+ demo.launch()