winamnd commited on
Commit
104c39e
·
verified ·
1 Parent(s): c623da2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -60
app.py CHANGED
@@ -1,87 +1,116 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
 
 
4
  import cv2
5
  import numpy as np
 
6
  import easyocr
7
  import keras_ocr
8
  from paddleocr import PaddleOCR
9
- import os
10
 
11
- # Ensure model config exists
12
  MODEL_PATH = "./distilbert_spam_model"
 
 
13
 
14
- if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
15
- print("config.json not found. Generating default configuration...")
16
- config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
17
- config.save_pretrained(MODEL_PATH)
18
-
19
- # Load tokenizer and model
20
- tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
21
- model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
 
 
 
22
 
23
- # Define Spam Classification Function
24
- def classify_text(text):
25
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
26
- with torch.no_grad():
27
- outputs = model(**inputs)
28
- logits = outputs.logits
29
- prediction = torch.argmax(logits, dim=-1).item()
30
- return "Spam" if prediction == 1 else "Not Spam"
31
-
32
- # OCR Methods
33
  def ocr_with_paddle(img):
34
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
35
  result = ocr.ocr(img)
36
- extracted_text = ' '.join([entry[1][0] for entry in result[0]])
37
- return extracted_text
38
 
39
  def ocr_with_keras(img):
40
  pipeline = keras_ocr.pipeline.Pipeline()
41
  images = [keras_ocr.tools.read(img)]
42
  predictions = pipeline.recognize(images)
43
- extracted_text = ' '.join([text for text, _ in predictions[0]])
44
- return extracted_text
45
 
46
  def ocr_with_easy(img):
 
47
  reader = easyocr.Reader(['en'])
48
- results = reader.readtext(img, detail=0)
49
  return ' '.join(results)
50
 
51
- # OCR + Spam Detection
52
- def process_image(ocr_method, image):
53
- if image is None:
54
- return "Error: No image uploaded."
55
-
56
- if ocr_method == "PaddleOCR":
57
- extracted_text = ocr_with_paddle(image)
58
- elif ocr_method == "KerasOCR":
59
- extracted_text = ocr_with_keras(image)
60
- elif ocr_method == "EasyOCR":
61
- extracted_text = ocr_with_easy(image)
62
- else:
63
- return "Invalid OCR method."
64
-
65
- if not extracted_text.strip():
66
- return "No text detected in the image."
67
-
68
- classification = classify_text(extracted_text)
69
- return f"Extracted Text: {extracted_text}\n\nClassification: {classification}"
70
-
71
- # Gradio UI
72
- image_input = gr.Image(type="numpy")
73
- ocr_method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="OCR Method")
74
- output_text = gr.Textbox(label="OCR & Classification Result")
75
-
76
- interface = gr.Interface(
77
- fn=process_image,
78
- inputs=[ocr_method_input, image_input],
79
- outputs=output_text,
80
- title="OCR + Spam Detection",
81
- description="Upload an image with text, extract the text using OCR, and classify it as Spam or Not Spam using DistilBERT.",
82
- theme="compact"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
 
85
- # Launch app
86
  if __name__ == "__main__":
87
- interface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import json
4
+ import csv
5
+ import os
6
  import cv2
7
  import numpy as np
8
+ import pandas as pd
9
  import easyocr
10
  import keras_ocr
11
  from paddleocr import PaddleOCR
12
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
13
 
14
+ # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
16
+ RESULTS_JSON = "ocr_results.json"
17
+ RESULTS_CSV = "ocr_results.csv"
18
 
19
+ # Ensure model exists
20
+ if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
21
+ print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
22
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
23
+ model.save_pretrained(MODEL_PATH)
24
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
25
+ tokenizer.save_pretrained(MODEL_PATH)
26
+ print(f"✅ Model saved at {MODEL_PATH}.")
27
+ else:
28
+ model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
29
+ tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
30
 
31
+ # Load OCR Methods
 
 
 
 
 
 
 
 
 
32
  def ocr_with_paddle(img):
33
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
34
  result = ocr.ocr(img)
35
+ return ' '.join([item[1][0] for item in result[0]])
 
36
 
37
  def ocr_with_keras(img):
38
  pipeline = keras_ocr.pipeline.Pipeline()
39
  images = [keras_ocr.tools.read(img)]
40
  predictions = pipeline.recognize(images)
41
+ return ' '.join([text for text, _ in predictions[0]])
 
42
 
43
  def ocr_with_easy(img):
44
+ gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
45
  reader = easyocr.Reader(['en'])
46
+ results = reader.readtext(gray_image, detail=0)
47
  return ' '.join(results)
48
 
49
+ # OCR Function
50
+ def generate_ocr(method, img):
51
+ if img is None:
52
+ raise gr.Error("Please upload an image!")
53
+
54
+ # Convert PIL Image to OpenCV format
55
+ img = np.array(img)
56
+
57
+ # Select OCR method
58
+ if method == "PaddleOCR":
59
+ text_output = ocr_with_paddle(img)
60
+ elif method == "EasyOCR":
61
+ text_output = ocr_with_easy(img)
62
+ else: # KerasOCR
63
+ text_output = ocr_with_keras(img)
64
+
65
+ # Classify Text as Spam or Not Spam
66
+ inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
67
+ with torch.no_grad():
68
+ outputs = model(**inputs)
69
+ prediction = torch.argmax(outputs.logits, dim=1).item()
70
+ label = "Spam" if prediction == 1 else "Not Spam"
71
+
72
+ # Save results
73
+ save_results(text_output, label)
74
+
75
+ return text_output, label
76
+
77
+ # Save extracted text to JSON & CSV
78
+ def save_results(text, label):
79
+ data = {"text": text, "label": label}
80
+
81
+ # Save to JSON
82
+ if not os.path.exists(RESULTS_JSON):
83
+ with open(RESULTS_JSON, "w") as f:
84
+ json.dump([], f)
85
+ with open(RESULTS_JSON, "r+") as f:
86
+ content = json.load(f)
87
+ content.append(data)
88
+ f.seek(0)
89
+ json.dump(content, f, indent=4)
90
+
91
+ # Save to CSV
92
+ file_exists = os.path.exists(RESULTS_CSV)
93
+ with open(RESULTS_CSV, "a", newline="") as f:
94
+ writer = csv.DictWriter(f, fieldnames=["text", "label"])
95
+ if not file_exists:
96
+ writer.writeheader()
97
+ writer.writerow(data)
98
+
99
+ # Gradio Interface
100
+ image_input = gr.Image()
101
+ method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
102
+ output_text = gr.Textbox(label="Extracted Text")
103
+ output_label = gr.Textbox(label="Spam Classification")
104
+
105
+ demo = gr.Interface(
106
+ generate_ocr,
107
+ inputs=[method_input, image_input],
108
+ outputs=[output_text, output_label],
109
+ title="OCR Spam Classifier",
110
+ description="Upload an image, extract text, and classify it as Spam or Not Spam.",
111
+ theme="compact",
112
  )
113
 
114
+ # Launch App
115
  if __name__ == "__main__":
116
+ demo.launch()