winamnd committed
Commit 2a250f6 · verified · 1 Parent(s): bf26c19

Update app.py

Files changed (1):
  app.py  +31 -17
app.py CHANGED
@@ -83,43 +83,57 @@ def generate_ocr(method, img):
 
     # Select OCR method
     if method == "PaddleOCR":
-        text_output = ocr_with_paddle(img)
+        extracted_text = ocr_with_paddle(img)
     elif method == "EasyOCR":
-        text_output = ocr_with_easy(img)
+        extracted_text = ocr_with_easy(img)
     elif method == "KerasOCR":
-        text_output = ocr_with_keras(img)
+        extracted_text = ocr_with_keras(img)
     elif method == "TesseractOCR":
-        text_output, _ = ocr_with_tesseract(img)  # Ignore confidence values
+        extracted_text, _ = ocr_with_tesseract(img)  # Ignore confidence values
     else:
         return "Invalid OCR method", "N/A"
 
-    # Clean and truncate the extracted text
-    text_output = text_output.strip()
-    if len(text_output) == 0:
+    # Clean text
+    extracted_text = extracted_text.strip()
+    if not extracted_text:
         return "No text detected!", "Cannot classify"
 
-    # Tokenize text for classification
-    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    # Debugging: Print extracted text
+    print(f"Extracted Text: {extracted_text}")
+
+    # Tokenize input
+    inputs = tokenizer(
+        extracted_text,
+        return_tensors="pt",
+        truncation=True,
+        padding="max_length",
+        max_length=512
+    )
+
+    # Move tensors to the same device as the model
+    inputs = {key: val.to(model.device) for key, val in inputs.items()}
 
     # Perform inference
     with torch.no_grad():
         outputs = model(**inputs)
-        logits = outputs.logits  # Get raw logits
+        logits = outputs.logits
 
-    # Debugging: Print raw logits
-    print(f"Raw logits: {logits}")
+    # Debugging: Print logits
+    print(f"Logits: {logits}")
 
-    # Use raw logits directly instead of softmax
+    # Use argmax to classify
     predicted_class = torch.argmax(logits, dim=1).item()
-
-    # Map class index to labels
     label_map = {0: "Not Spam", 1: "Spam"}
     label = label_map.get(predicted_class, "Unknown")
 
+    # Debugging: Print final classification
+    print(f"Predicted Class: {predicted_class}, Label: {label}")
+
     # Save results
-    save_results_to_repo(text_output, label)
+    save_results_to_repo(extracted_text, label)
+
+    return extracted_text, label
 
-    return text_output, label
 
     # Gradio Interface
     image_input = gr.Image()
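
For reference, the updated generate_ocr is presumably exposed through a Gradio interface roughly like the sketch below. This is a minimal, assumed wiring: only image_input = gr.Image() appears in the diff, so the method selector, output components, and launch() call are illustrative guesses rather than the file's actual code.

# Hypothetical wiring of generate_ocr into a Gradio interface.
# Only `image_input = gr.Image()` is visible in the diff above; the
# method selector, output boxes, and launch() call are assumptions.
import gradio as gr

method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"],
    label="OCR Method",
)
image_input = gr.Image()

demo = gr.Interface(
    fn=generate_ocr,                     # defined earlier in app.py
    inputs=[method_input, image_input],  # matches generate_ocr(method, img)
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.Textbox(label="Classification"),
    ],
)

if __name__ == "__main__":
    demo.launch()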