Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -100,7 +100,7 @@ def generate_ocr(method, image):
|
|
100 |
if len(text_output) == 0:
|
101 |
return "No text detected!", "Cannot classify"
|
102 |
|
103 |
-
# Tokenize text for
|
104 |
inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
105 |
|
106 |
# Perform inference
|
@@ -108,27 +108,30 @@ def generate_ocr(method, image):
|
|
108 |
outputs = model(**inputs)
|
109 |
logits = outputs.logits # Get raw logits
|
110 |
|
111 |
-
# Print raw logits
|
112 |
print(f"Raw logits: {logits}")
|
113 |
|
114 |
-
# Convert logits to probabilities
|
115 |
probs = F.softmax(logits, dim=1)
|
116 |
-
|
117 |
# Extract probability values
|
118 |
not_spam_prob = probs[0, 0].item()
|
119 |
spam_prob = probs[0, 1].item()
|
120 |
|
121 |
-
# Print
|
122 |
print(f"Not Spam Probability: {not_spam_prob}, Spam Probability: {spam_prob}")
|
123 |
|
124 |
-
#
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
127 |
label = "Spam"
|
128 |
else:
|
129 |
label = "Not Spam"
|
130 |
|
131 |
-
# Save results
|
132 |
save_results_to_repo(text_output, label)
|
133 |
|
134 |
return text_output, label
|
|
|
100 |
if len(text_output) == 0:
|
101 |
return "No text detected!", "Cannot classify"
|
102 |
|
103 |
+
# Tokenize text for classification
|
104 |
inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
105 |
|
106 |
# Perform inference
|
|
|
108 |
outputs = model(**inputs)
|
109 |
logits = outputs.logits # Get raw logits
|
110 |
|
111 |
+
# Print raw logits to debug
|
112 |
print(f"Raw logits: {logits}")
|
113 |
|
114 |
+
# Convert logits to probabilities using softmax
|
115 |
probs = F.softmax(logits, dim=1)
|
116 |
+
|
117 |
# Extract probability values
|
118 |
not_spam_prob = probs[0, 0].item()
|
119 |
spam_prob = probs[0, 1].item()
|
120 |
|
121 |
+
# Print probability values for debugging
|
122 |
print(f"Not Spam Probability: {not_spam_prob}, Spam Probability: {spam_prob}")
|
123 |
|
124 |
+
# Ensure correct label mapping
|
125 |
+
predicted_class = torch.argmax(probs, dim=1).item() # Get predicted class index
|
126 |
+
print(f"Predicted Class Index: {predicted_class}") # Debugging output
|
127 |
+
|
128 |
+
# Check if the labels are flipped
|
129 |
+
if predicted_class == 1:
|
130 |
label = "Spam"
|
131 |
else:
|
132 |
label = "Not Spam"
|
133 |
|
134 |
+
# Save results
|
135 |
save_results_to_repo(text_output, label)
|
136 |
|
137 |
return text_output, label
|