winamnd commited on
Commit
7dedea0
·
verified ·
1 Parent(s): 2a250f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -41
app.py CHANGED
@@ -73,6 +73,24 @@ def ocr_with_tesseract(img):
73
  confidences = [1.0] * len(extracted_text) # Tesseract doesn't return confidence scores
74
  return extracted_text, confidences
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # OCR & Classification Function
77
  def generate_ocr(method, img):
78
  if img is None:
@@ -83,57 +101,33 @@ def generate_ocr(method, img):
83
 
84
  # Select OCR method
85
  if method == "PaddleOCR":
86
- extracted_text = ocr_with_paddle(img)
87
  elif method == "EasyOCR":
88
- extracted_text = ocr_with_easy(img)
89
- elif method == "KerasOCR":
90
- extracted_text = ocr_with_keras(img)
91
- elif method == "TesseractOCR":
92
- extracted_text, _ = ocr_with_tesseract(img) # Ignore confidence values
93
- else:
94
- return "Invalid OCR method", "N/A"
95
-
96
- # Clean text
97
- extracted_text = extracted_text.strip()
98
- if not extracted_text:
99
- return "No text detected!", "Cannot classify"
100
 
101
- # Debugging: Print extracted text
102
- print(f"Extracted Text: {extracted_text}")
103
-
104
- # Tokenize input
105
- inputs = tokenizer(
106
- extracted_text,
107
- return_tensors="pt",
108
- truncation=True,
109
- padding="max_length",
110
- max_length=512
111
- )
112
 
113
- # Move tensors to the same device as the model
114
- inputs = {key: val.to(model.device) for key, val in inputs.items()}
115
 
116
  # Perform inference
117
  with torch.no_grad():
118
  outputs = model(**inputs)
119
- logits = outputs.logits
120
-
121
- # Debugging: Print logits
122
- print(f"Logits: {logits}")
123
-
124
- # Use argmax to classify
125
- predicted_class = torch.argmax(logits, dim=1).item()
126
- label_map = {0: "Not Spam", 1: "Spam"}
127
- label = label_map.get(predicted_class, "Unknown")
128
-
129
- # Debugging: Print final classification
130
- print(f"Predicted Class: {predicted_class}, Label: {label}")
131
 
132
- # Save results
133
- save_results_to_repo(extracted_text, label)
134
 
135
- return extracted_text, label
 
136
 
 
137
 
138
  # Gradio Interface
139
  image_input = gr.Image()
 
73
  confidences = [1.0] * len(extracted_text) # Tesseract doesn't return confidence scores
74
  return extracted_text, confidences
75
 
76
+ # OCR & Classification Function
77
+ def ocr_with_paddle(img):
78
+ ocr = PaddleOCR(lang='en', use_angle_cls=True)
79
+ result = ocr.ocr(img)
80
+ return ' '.join([item[1][0] for item in result[0]])
81
+
82
+ def ocr_with_keras(img):
83
+ pipeline = keras_ocr.pipeline.Pipeline()
84
+ images = [keras_ocr.tools.read(img)]
85
+ predictions = pipeline.recognize(images)
86
+ return ' '.join([text for text, _ in predictions[0]])
87
+
88
+ def ocr_with_easy(img):
89
+ gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
90
+ reader = easyocr.Reader(['en'])
91
+ results = reader.readtext(gray_image, detail=0)
92
+ return ' '.join(results)
93
+
94
  # OCR & Classification Function
95
  def generate_ocr(method, img):
96
  if img is None:
 
101
 
102
  # Select OCR method
103
  if method == "PaddleOCR":
104
+ text_output = ocr_with_paddle(img)
105
  elif method == "EasyOCR":
106
+ text_output = ocr_with_easy(img)
107
+ else: # KerasOCR
108
+ text_output = ocr_with_keras(img)
 
 
 
 
 
 
 
 
 
109
 
110
+ # Preprocess text properly
111
+ text_output = text_output.strip()
112
+ if len(text_output) == 0:
113
+ return "No text detected!", "Cannot classify"
 
 
 
 
 
 
 
114
 
115
+ # Tokenize text
116
+ inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
117
 
118
  # Perform inference
119
  with torch.no_grad():
120
  outputs = model(**inputs)
121
+ probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
122
+ spam_prob = probs[0][1].item() # Probability of Spam
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # Adjust classification based on threshold (better than argmax)
125
+ label = "Spam" if spam_prob > 0.5 else "Not Spam"
126
 
127
+ # Save results using external function
128
+ save_results_to_repo(text_output, label)
129
 
130
+ return text_output, label
131
 
132
  # Gradio Interface
133
  image_input = gr.Image()