hackerbyhobby commited on
Commit
05a7e7d
·
unverified ·
1 Parent(s): ca159d8

updated app to have user choose text or OCR and hide elements

Browse files
Files changed (2) hide show
  1. app.py +69 -53
  2. app.py.working_ocr_selection +194 -0
app.py CHANGED
@@ -21,16 +21,15 @@ model_name = "joeddav/xlm-roberta-large-xnli"
21
  classifier = pipeline("zero-shot-classification", model=model_name)
22
  CANDIDATE_LABELS = ["SMiShing", "Other Scam", "Legitimate"]
23
 
24
-
25
  def get_keywords_by_language(text: str):
26
  """
27
  Detect language using `langdetect` and translate keywords if needed.
28
  """
29
- snippet = text[:200] # Use a snippet for detection
30
  try:
31
  detected_lang = detect(snippet)
32
  except Exception:
33
- detected_lang = "en" # Default to English if detection fails
34
 
35
  if detected_lang == "es":
36
  smishing_in_spanish = [
@@ -43,7 +42,6 @@ def get_keywords_by_language(text: str):
43
  else:
44
  return SMISHING_KEYWORDS, OTHER_SCAM_KEYWORDS, "en"
45
 
46
-
47
  def boost_probabilities(probabilities: dict, text: str):
48
  """
49
  Boost probabilities based on keyword matches and presence of URLs.
@@ -54,13 +52,11 @@ def boost_probabilities(probabilities: dict, text: str):
54
  smishing_count = sum(1 for kw in smishing_keywords if kw in lower_text)
55
  other_scam_count = sum(1 for kw in other_scam_keywords if kw in lower_text)
56
 
57
- # Example: 30% per found keyword
58
  smishing_boost = 0.30 * smishing_count
59
  other_scam_boost = 0.30 * other_scam_count
60
 
61
  found_urls = re.findall(r"(https?://[^\s]+)", lower_text)
62
  if found_urls:
63
- # 35% boost for Smishing if there's a URL
64
  smishing_boost += 0.35
65
 
66
  p_smishing = probabilities.get("SMiShing", 0.0)
@@ -71,7 +67,7 @@ def boost_probabilities(probabilities: dict, text: str):
71
  p_other_scam += other_scam_boost
72
  p_legit -= (smishing_boost + other_scam_boost)
73
 
74
- # Clamp to 0
75
  p_smishing = max(p_smishing, 0.0)
76
  p_other_scam = max(p_other_scam, 0.0)
77
  p_legit = max(p_legit, 0.0)
@@ -92,23 +88,18 @@ def boost_probabilities(probabilities: dict, text: str):
92
  "detected_lang": detected_lang
93
  }
94
 
95
-
96
  def smishing_detector(input_type, text, image):
97
  """
98
- Main detection function:
99
- - If input_type == "Text": use `text` as the message
100
- - If input_type == "Screenshot": use OCR on `image` to get text
101
  """
102
  if input_type == "Text":
103
- # Use the pasted text
104
  combined_text = text.strip() if text else ""
105
  else:
106
  # input_type == "Screenshot"
 
107
  if image is not None:
108
- ocr_text = pytesseract.image_to_string(image, lang="spa+eng")
109
- combined_text = ocr_text.strip()
110
- else:
111
- combined_text = ""
112
 
113
  if not combined_text:
114
  return {
@@ -129,16 +120,12 @@ def smishing_detector(input_type, text, image):
129
 
130
  # Boost logic
131
  boosted = boost_probabilities(original_probs, combined_text)
132
-
133
- # Convert to float
134
  boosted = {k: float(v) for k, v in boosted.items() if isinstance(v, (int, float))}
135
- detected_lang = boosted.pop("detected_lang", "en")
136
 
137
- # Final classification
138
  final_label = max(boosted, key=boosted.get)
139
  final_confidence = round(boosted[final_label], 3)
140
 
141
- # For display
142
  lower_text = combined_text.lower()
143
  smishing_keys, scam_keys, _ = get_keywords_by_language(combined_text)
144
 
@@ -149,8 +136,12 @@ def smishing_detector(input_type, text, image):
149
  return {
150
  "detected_language": detected_lang,
151
  "text_used_for_classification": combined_text,
152
- "original_probabilities": {k: round(v, 3) for k, v in original_probs.items()},
153
- "boosted_probabilities": {k: round(v, 3) for k, v in boosted.items()},
 
 
 
 
154
  "label": final_label,
155
  "confidence": final_confidence,
156
  "smishing_keywords_found": found_smishing,
@@ -158,37 +149,62 @@ def smishing_detector(input_type, text, image):
158
  "urls_found": found_urls,
159
  }
160
 
161
-
162
- # Create a Radio for user choice + text input + image input
163
- demo = gr.Interface(
164
- fn=smishing_detector,
165
- inputs=[
166
- gr.Radio(
167
- choices=["Text", "Screenshot"],
168
- label="Choose input type",
169
- value="Text", # default
170
- info="Select 'Text' to paste a message, or 'Screenshot' to upload an image."
171
- ),
172
- gr.Textbox(
173
- lines=3,
174
- label="Paste Suspicious SMS Text",
175
- placeholder="Type or paste the message here..."
176
- ),
177
- gr.Image(
178
- type="pil",
179
- label="Upload a Screenshot",
 
 
 
 
180
  )
181
- ],
182
- outputs="json",
183
- title="SMiShing & Scam Detector",
184
- description="""
185
- Select "Text" or "Screenshot" above.
186
- - If "Text", only use the textbox.
187
- - If "Screenshot", only upload an image.
188
- The app will classify the message as SMiShing, Other Scam, or Legitimate.
189
- """,
190
- allow_flagging="never"
191
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  if __name__ == "__main__":
194
  demo.launch()
 
21
  classifier = pipeline("zero-shot-classification", model=model_name)
22
  CANDIDATE_LABELS = ["SMiShing", "Other Scam", "Legitimate"]
23
 
 
24
  def get_keywords_by_language(text: str):
25
  """
26
  Detect language using `langdetect` and translate keywords if needed.
27
  """
28
+ snippet = text[:200]
29
  try:
30
  detected_lang = detect(snippet)
31
  except Exception:
32
+ detected_lang = "en"
33
 
34
  if detected_lang == "es":
35
  smishing_in_spanish = [
 
42
  else:
43
  return SMISHING_KEYWORDS, OTHER_SCAM_KEYWORDS, "en"
44
 
 
45
  def boost_probabilities(probabilities: dict, text: str):
46
  """
47
  Boost probabilities based on keyword matches and presence of URLs.
 
52
  smishing_count = sum(1 for kw in smishing_keywords if kw in lower_text)
53
  other_scam_count = sum(1 for kw in other_scam_keywords if kw in lower_text)
54
 
 
55
  smishing_boost = 0.30 * smishing_count
56
  other_scam_boost = 0.30 * other_scam_count
57
 
58
  found_urls = re.findall(r"(https?://[^\s]+)", lower_text)
59
  if found_urls:
 
60
  smishing_boost += 0.35
61
 
62
  p_smishing = probabilities.get("SMiShing", 0.0)
 
67
  p_other_scam += other_scam_boost
68
  p_legit -= (smishing_boost + other_scam_boost)
69
 
70
+ # Clamp
71
  p_smishing = max(p_smishing, 0.0)
72
  p_other_scam = max(p_other_scam, 0.0)
73
  p_legit = max(p_legit, 0.0)
 
88
  "detected_lang": detected_lang
89
  }
90
 
 
91
  def smishing_detector(input_type, text, image):
92
  """
93
+ Only use the textbox if input_type == "Text",
94
+ otherwise perform OCR on the image if input_type == "Screenshot".
 
95
  """
96
  if input_type == "Text":
 
97
  combined_text = text.strip() if text else ""
98
  else:
99
  # input_type == "Screenshot"
100
+ combined_text = ""
101
  if image is not None:
102
+ combined_text = pytesseract.image_to_string(image, lang="spa+eng").strip()
 
 
 
103
 
104
  if not combined_text:
105
  return {
 
120
 
121
  # Boost logic
122
  boosted = boost_probabilities(original_probs, combined_text)
 
 
123
  boosted = {k: float(v) for k, v in boosted.items() if isinstance(v, (int, float))}
 
124
 
125
+ detected_lang = boosted.pop("detected_lang", "en")
126
  final_label = max(boosted, key=boosted.get)
127
  final_confidence = round(boosted[final_label], 3)
128
 
 
129
  lower_text = combined_text.lower()
130
  smishing_keys, scam_keys, _ = get_keywords_by_language(combined_text)
131
 
 
136
  return {
137
  "detected_language": detected_lang,
138
  "text_used_for_classification": combined_text,
139
+ "original_probabilities": {
140
+ k: round(v, 3) for k, v in original_probs.items()
141
+ },
142
+ "boosted_probabilities": {
143
+ k: round(v, 3) for k, v in boosted.items()
144
+ },
145
  "label": final_label,
146
  "confidence": final_confidence,
147
  "smishing_keywords_found": found_smishing,
 
149
  "urls_found": found_urls,
150
  }
151
 
152
+ #
153
+ # Gradio interface with dynamic visibility
154
+ #
155
+ def toggle_inputs(choice):
156
+ """
157
+ Return updates for (text_input, image_input) based on the radio selection.
158
+ """
159
+ if choice == "Text":
160
+ # Show text input, hide image
161
+ return gr.update(visible=True), gr.update(visible=False)
162
+ else:
163
+ # choice == "Screenshot"
164
+ # Hide text input, show image
165
+ return gr.update(visible=False), gr.update(visible=True)
166
+
167
+ with gr.Blocks() as demo:
168
+ gr.Markdown("## SMiShing & Scam Detector (Choose Text or Screenshot)")
169
+
170
+ with gr.Row():
171
+ input_type = gr.Radio(
172
+ choices=["Text", "Screenshot"],
173
+ value="Text",
174
+ label="Choose Input Type"
175
  )
176
+
177
+ text_input = gr.Textbox(
178
+ lines=3,
179
+ label="Paste Suspicious SMS Text",
180
+ placeholder="Type or paste the message here...",
181
+ visible=True # default
182
+ )
183
+
184
+ image_input = gr.Image(
185
+ type="pil",
186
+ label="Upload Screenshot",
187
+ visible=False # hidden by default
188
+ )
189
+
190
+ # Whenever input_type changes, toggle which input is visible
191
+ input_type.change(
192
+ fn=toggle_inputs,
193
+ inputs=input_type,
194
+ outputs=[text_input, image_input],
195
+ queue=False
196
+ )
197
+
198
+ # Button to run classification
199
+ analyze_btn = gr.Button("Classify")
200
+ output_json = gr.JSON(label="Result")
201
+
202
+ # On button click, call the smishing_detector
203
+ analyze_btn.click(
204
+ fn=smishing_detector,
205
+ inputs=[input_type, text_input, image_input],
206
+ outputs=output_json
207
+ )
208
 
209
  if __name__ == "__main__":
210
  demo.launch()
app.py.working_ocr_selection ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pytesseract
3
+ from PIL import Image
4
+ from transformers import pipeline
5
+ import re
6
+ from langdetect import detect
7
+ from deep_translator import GoogleTranslator
8
+
9
+ # Translator instance
10
+ translator = GoogleTranslator(source="auto", target="es")
11
+
12
+ # 1. Load separate keywords for SMiShing and Other Scam (assumed in English)
13
+ with open("smishing_keywords.txt", "r", encoding="utf-8") as f:
14
+ SMISHING_KEYWORDS = [line.strip().lower() for line in f if line.strip()]
15
+
16
+ with open("other_scam_keywords.txt", "r", encoding="utf-8") as f:
17
+ OTHER_SCAM_KEYWORDS = [line.strip().lower() for line in f if line.strip()]
18
+
19
+ # 2. Zero-Shot Classification Pipeline
20
+ model_name = "joeddav/xlm-roberta-large-xnli"
21
+ classifier = pipeline("zero-shot-classification", model=model_name)
22
+ CANDIDATE_LABELS = ["SMiShing", "Other Scam", "Legitimate"]
23
+
24
+
25
+ def get_keywords_by_language(text: str):
26
+ """
27
+ Detect language using `langdetect` and translate keywords if needed.
28
+ """
29
+ snippet = text[:200] # Use a snippet for detection
30
+ try:
31
+ detected_lang = detect(snippet)
32
+ except Exception:
33
+ detected_lang = "en" # Default to English if detection fails
34
+
35
+ if detected_lang == "es":
36
+ smishing_in_spanish = [
37
+ translator.translate(kw).lower() for kw in SMISHING_KEYWORDS
38
+ ]
39
+ other_scam_in_spanish = [
40
+ translator.translate(kw).lower() for kw in OTHER_SCAM_KEYWORDS
41
+ ]
42
+ return smishing_in_spanish, other_scam_in_spanish, "es"
43
+ else:
44
+ return SMISHING_KEYWORDS, OTHER_SCAM_KEYWORDS, "en"
45
+
46
+
47
+ def boost_probabilities(probabilities: dict, text: str):
48
+ """
49
+ Boost probabilities based on keyword matches and presence of URLs.
50
+ """
51
+ lower_text = text.lower()
52
+ smishing_keywords, other_scam_keywords, detected_lang = get_keywords_by_language(text)
53
+
54
+ smishing_count = sum(1 for kw in smishing_keywords if kw in lower_text)
55
+ other_scam_count = sum(1 for kw in other_scam_keywords if kw in lower_text)
56
+
57
+ # Example: 30% per found keyword
58
+ smishing_boost = 0.30 * smishing_count
59
+ other_scam_boost = 0.30 * other_scam_count
60
+
61
+ found_urls = re.findall(r"(https?://[^\s]+)", lower_text)
62
+ if found_urls:
63
+ # 35% boost for Smishing if there's a URL
64
+ smishing_boost += 0.35
65
+
66
+ p_smishing = probabilities.get("SMiShing", 0.0)
67
+ p_other_scam = probabilities.get("Other Scam", 0.0)
68
+ p_legit = probabilities.get("Legitimate", 1.0)
69
+
70
+ p_smishing += smishing_boost
71
+ p_other_scam += other_scam_boost
72
+ p_legit -= (smishing_boost + other_scam_boost)
73
+
74
+ # Clamp to 0
75
+ p_smishing = max(p_smishing, 0.0)
76
+ p_other_scam = max(p_other_scam, 0.0)
77
+ p_legit = max(p_legit, 0.0)
78
+
79
+ # Re-normalize
80
+ total = p_smishing + p_other_scam + p_legit
81
+ if total > 0:
82
+ p_smishing /= total
83
+ p_other_scam /= total
84
+ p_legit /= total
85
+ else:
86
+ p_smishing, p_other_scam, p_legit = 0.0, 0.0, 1.0
87
+
88
+ return {
89
+ "SMiShing": p_smishing,
90
+ "Other Scam": p_other_scam,
91
+ "Legitimate": p_legit,
92
+ "detected_lang": detected_lang
93
+ }
94
+
95
+
96
+ def smishing_detector(input_type, text, image):
97
+ """
98
+ Main detection function:
99
+ - If input_type == "Text": use `text` as the message
100
+ - If input_type == "Screenshot": use OCR on `image` to get text
101
+ """
102
+ if input_type == "Text":
103
+ # Use the pasted text
104
+ combined_text = text.strip() if text else ""
105
+ else:
106
+ # input_type == "Screenshot"
107
+ if image is not None:
108
+ ocr_text = pytesseract.image_to_string(image, lang="spa+eng")
109
+ combined_text = ocr_text.strip()
110
+ else:
111
+ combined_text = ""
112
+
113
+ if not combined_text:
114
+ return {
115
+ "text_used_for_classification": "(none)",
116
+ "label": "No text provided",
117
+ "confidence": 0.0,
118
+ "keywords_found": [],
119
+ "urls_found": []
120
+ }
121
+
122
+ # Zero-shot classification
123
+ result = classifier(
124
+ sequences=combined_text,
125
+ candidate_labels=CANDIDATE_LABELS,
126
+ hypothesis_template="This message is {}."
127
+ )
128
+ original_probs = {k: float(v) for k, v in zip(result["labels"], result["scores"])}
129
+
130
+ # Boost logic
131
+ boosted = boost_probabilities(original_probs, combined_text)
132
+
133
+ # Convert to float
134
+ boosted = {k: float(v) for k, v in boosted.items() if isinstance(v, (int, float))}
135
+ detected_lang = boosted.pop("detected_lang", "en")
136
+
137
+ # Final classification
138
+ final_label = max(boosted, key=boosted.get)
139
+ final_confidence = round(boosted[final_label], 3)
140
+
141
+ # For display
142
+ lower_text = combined_text.lower()
143
+ smishing_keys, scam_keys, _ = get_keywords_by_language(combined_text)
144
+
145
+ found_urls = re.findall(r"(https?://[^\s]+)", lower_text)
146
+ found_smishing = [kw for kw in smishing_keys if kw in lower_text]
147
+ found_other_scam = [kw for kw in scam_keys if kw in lower_text]
148
+
149
+ return {
150
+ "detected_language": detected_lang,
151
+ "text_used_for_classification": combined_text,
152
+ "original_probabilities": {k: round(v, 3) for k, v in original_probs.items()},
153
+ "boosted_probabilities": {k: round(v, 3) for k, v in boosted.items()},
154
+ "label": final_label,
155
+ "confidence": final_confidence,
156
+ "smishing_keywords_found": found_smishing,
157
+ "other_scam_keywords_found": found_other_scam,
158
+ "urls_found": found_urls,
159
+ }
160
+
161
+
162
+ # Create a Radio for user choice + text input + image input
163
+ demo = gr.Interface(
164
+ fn=smishing_detector,
165
+ inputs=[
166
+ gr.Radio(
167
+ choices=["Text", "Screenshot"],
168
+ label="Choose input type",
169
+ value="Text", # default
170
+ info="Select 'Text' to paste a message, or 'Screenshot' to upload an image."
171
+ ),
172
+ gr.Textbox(
173
+ lines=3,
174
+ label="Paste Suspicious SMS Text",
175
+ placeholder="Type or paste the message here..."
176
+ ),
177
+ gr.Image(
178
+ type="pil",
179
+ label="Upload a Screenshot",
180
+ )
181
+ ],
182
+ outputs="json",
183
+ title="SMiShing & Scam Detector",
184
+ description="""
185
+ Select "Text" or "Screenshot" above.
186
+ - If "Text", only use the textbox.
187
+ - If "Screenshot", only upload an image.
188
+ The app will classify the message as SMiShing, Other Scam, or Legitimate.
189
+ """,
190
+ allow_flagging="never"
191
+ )
192
+
193
+ if __name__ == "__main__":
194
+ demo.launch()