winamnd commited on
Commit
b2a51c8
·
verified ·
1 Parent(s): a7de18e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -23
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  import json
4
- import csv
5
  import os
6
  import cv2
7
  import numpy as np
@@ -10,12 +9,17 @@ import keras_ocr
10
  from paddleocr import PaddleOCR
11
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
  import torch.nn.functional as F
13
- from save_results import save_results_to_repo # Import the save function
 
 
 
 
 
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
17
 
18
- # Ensure model exists
19
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
20
  print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
21
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
@@ -27,49 +31,76 @@ else:
27
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
28
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
29
 
30
- # 🔹 Ensure model is in evaluation mode
31
  model.eval()
32
 
33
- # OCR Functions (No changes here)
 
 
 
 
 
34
  def ocr_with_paddle(img):
35
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
36
  result = ocr.ocr(img)
37
- return ' '.join([item[1][0] for item in result[0]])
 
 
 
 
 
38
 
39
  def ocr_with_keras(img):
40
  pipeline = keras_ocr.pipeline.Pipeline()
41
  images = [keras_ocr.tools.read(img)]
42
  predictions = pipeline.recognize(images)
43
- return ' '.join([text for text, _ in predictions[0]])
 
 
44
 
45
  def ocr_with_easy(img):
46
  gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
47
  reader = easyocr.Reader(['en'])
48
- results = reader.readtext(gray_image, detail=0)
49
- return ' '.join(results)
 
 
 
 
 
 
 
 
 
50
 
51
  # OCR & Classification Function
52
- def generate_ocr(method, img):
53
- if img is None:
54
  raise gr.Error("Please upload an image!")
55
 
56
  # Convert PIL Image to OpenCV format
57
- img = np.array(img)
58
 
59
  # Select OCR method
60
  if method == "PaddleOCR":
61
- text_output = ocr_with_paddle(img)
62
  elif method == "EasyOCR":
63
- text_output = ocr_with_easy(img)
64
- else: # KerasOCR
65
- text_output = ocr_with_keras(img)
66
-
67
- # Preprocess text properly
68
- text_output = text_output.strip()
 
 
 
 
 
 
69
  if len(text_output) == 0:
70
  return "No text detected!", "Cannot classify"
71
 
72
- # Tokenize text
73
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
74
 
75
  # Perform inference
@@ -78,7 +109,7 @@ def generate_ocr(method, img):
78
  probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
79
  spam_prob = probs[0][1].item() # Probability of Spam
80
 
81
- # Adjust classification based on threshold (better than argmax)
82
  label = "Spam" if spam_prob > 0.5 else "Not Spam"
83
 
84
  # Save results using external function
@@ -88,7 +119,7 @@ def generate_ocr(method, img):
88
 
89
  # Gradio Interface
90
  image_input = gr.Image()
91
- method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
92
  output_text = gr.Textbox(label="Extracted Text")
93
  output_label = gr.Textbox(label="Spam Classification")
94
 
@@ -97,7 +128,7 @@ demo = gr.Interface(
97
  inputs=[method_input, image_input],
98
  outputs=[output_text, output_label],
99
  title="OCR Spam Classifier",
100
- description="Upload an image, extract text, and classify it as Spam or Not Spam.",
101
  theme="compact",
102
  )
103
 
 
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  import os
5
  import cv2
6
  import numpy as np
 
9
  from paddleocr import PaddleOCR
10
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
11
  import torch.nn.functional as F
12
+ from PIL import Image
13
+ import pytesseract
14
+ import io
15
+
16
+ # Import save function
17
+ from save_results import save_results_to_repo
18
 
19
  # Paths
20
  MODEL_PATH = "./distilbert_spam_model"
21
 
22
+ # Ensure LLM Model exists
23
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
24
  print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
25
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
 
31
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
32
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
33
 
34
+ # Ensure model is in evaluation mode
35
  model.eval()
36
 
37
+ # Function to process image for OCR
38
+ def preprocess_image(image):
39
+ """Convert PIL image to OpenCV format (NumPy array)"""
40
+ return np.array(image)
41
+
42
+ # OCR Functions (same as ocr-api)
43
  def ocr_with_paddle(img):
44
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
45
  result = ocr.ocr(img)
46
+ extracted_text, confidences = [], []
47
+ for line in result[0]:
48
+ text, confidence = line[1]
49
+ extracted_text.append(text)
50
+ confidences.append(confidence)
51
+ return extracted_text, confidences
52
 
53
  def ocr_with_keras(img):
54
  pipeline = keras_ocr.pipeline.Pipeline()
55
  images = [keras_ocr.tools.read(img)]
56
  predictions = pipeline.recognize(images)
57
+ extracted_text = [text for text, confidence in predictions[0]]
58
+ confidences = [confidence for text, confidence in predictions[0]]
59
+ return extracted_text, confidences
60
 
61
  def ocr_with_easy(img):
62
  gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
63
  reader = easyocr.Reader(['en'])
64
+ results = reader.readtext(gray_image)
65
+ extracted_text = [text for _, text, confidence in results]
66
+ confidences = [confidence for _, text, confidence in results]
67
+ return extracted_text, confidences
68
+
69
+ def ocr_with_tesseract(img):
70
+ gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
71
+ extracted_text = pytesseract.image_to_string(gray_image).split("\n")
72
+ extracted_text = [line.strip() for line in extracted_text if line.strip()]
73
+ confidences = [1.0] * len(extracted_text) # Tesseract doesn't return confidence scores
74
+ return extracted_text, confidences
75
 
76
  # OCR & Classification Function
77
+ def generate_ocr(method, image):
78
+ if image is None:
79
  raise gr.Error("Please upload an image!")
80
 
81
  # Convert PIL Image to OpenCV format
82
+ img_cv = preprocess_image(image)
83
 
84
  # Select OCR method
85
  if method == "PaddleOCR":
86
+ extracted_text, confidences = ocr_with_paddle(img_cv)
87
  elif method == "EasyOCR":
88
+ extracted_text, confidences = ocr_with_easy(img_cv)
89
+ elif method == "KerasOCR":
90
+ extracted_text, confidences = ocr_with_keras(img_cv)
91
+ elif method == "TesseractOCR":
92
+ extracted_text, confidences = ocr_with_tesseract(img_cv)
93
+ else:
94
+ return "Invalid OCR method", "N/A"
95
+
96
+ # Join extracted text into a single string
97
+ text_output = " ".join(extracted_text).strip()
98
+
99
+ # If no text detected, return early
100
  if len(text_output) == 0:
101
  return "No text detected!", "Cannot classify"
102
 
103
+ # Tokenize text for LLM classification
104
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
105
 
106
  # Perform inference
 
109
  probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
110
  spam_prob = probs[0][1].item() # Probability of Spam
111
 
112
+ # Adjust classification based on threshold
113
  label = "Spam" if spam_prob > 0.5 else "Not Spam"
114
 
115
  # Save results using external function
 
119
 
120
  # Gradio Interface
121
  image_input = gr.Image()
122
+ method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"], value="PaddleOCR")
123
  output_text = gr.Textbox(label="Extracted Text")
124
  output_label = gr.Textbox(label="Spam Classification")
125
 
 
128
  inputs=[method_input, image_input],
129
  outputs=[output_text, output_label],
130
  title="OCR Spam Classifier",
131
+ description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
132
  theme="compact",
133
  )
134