winamnd committed
Commit ffe536a · verified · 1 Parent(s): db8a1e5

Update app.py

Files changed (1)
  1. app.py +12 -9
app.py CHANGED
@@ -100,7 +100,7 @@ def generate_ocr(method, image):
     if len(text_output) == 0:
         return "No text detected!", "Cannot classify"
 
-    # Tokenize text for LLM classification
+    # Tokenize text for classification
     inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
     # Perform inference
@@ -108,27 +108,30 @@ def generate_ocr(method, image):
     outputs = model(**inputs)
     logits = outputs.logits  # Get raw logits
 
-    # Print raw logits for debugging
+    # Print raw logits to debug
    print(f"Raw logits: {logits}")
 
-    # Convert logits to probabilities
+    # Convert logits to probabilities using softmax
     probs = F.softmax(logits, dim=1)
-
+
     # Extract probability values
     not_spam_prob = probs[0, 0].item()
     spam_prob = probs[0, 1].item()
 
-    # Print probabilities for debugging
+    # Print probability values for debugging
     print(f"Not Spam Probability: {not_spam_prob}, Spam Probability: {spam_prob}")
 
-    # Use a classification threshold to avoid bias
-    threshold = 0.55  # Adjust based on observations
-    if spam_prob >= threshold:
+    # Ensure correct label mapping
+    predicted_class = torch.argmax(probs, dim=1).item()  # Get predicted class index
+    print(f"Predicted Class Index: {predicted_class}")  # Debugging output
+
+    # Check if the labels are flipped
+    if predicted_class == 1:
         label = "Spam"
     else:
         label = "Not Spam"
 
-    # Save results using external function
+    # Save results
     save_results_to_repo(text_output, label)
 
     return text_output, label
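For context, this commit swaps the fixed 0.55 spam threshold for an argmax over the softmax probabilities. Below is a minimal standalone sketch of that decision step, assuming a two-class output head with index 0 = not spam and index 1 = spam; the logits are placeholder values for illustration, not taken from app.py.

import torch
import torch.nn.functional as F

# Placeholder logits for a single input: [not_spam, spam]. Real values come from the model.
logits = torch.tensor([[1.2, -0.3]])

probs = F.softmax(logits, dim=1)                     # convert raw logits to probabilities
predicted_class = torch.argmax(probs, dim=1).item()  # 0 = not spam, 1 = spam

label = "Spam" if predicted_class == 1 else "Not Spam"
print(f"probs={probs.tolist()}, predicted_class={predicted_class}, label={label}")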