Eemansleepdeprived commited on
Commit
98a24df
·
verified ·
1 Parent(s): 5f25888

Create new.py

Browse files
Files changed (1) hide show
  1. new.py +328 -0
new.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.hub
4
+ import re
5
+ import os
6
+ import time
7
+
8
# --- Set Page Config First ---
# st.set_page_config must be the very first Streamlit command on the page,
# otherwise Streamlit raises an error — hence it precedes the CSS injection.
st.set_page_config(
    page_title="AI Text Detector",
    layout="centered",
    initial_sidebar_state="collapsed"
)
+
15
# --- Improved CSS for a cleaner UI ---
# Injected as raw HTML (unsafe_allow_html=True); selectors target Streamlit's
# generated widget classes, so they may need updating across Streamlit versions.
st.markdown("""
<style>
/* Modern clean font for the entire app */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

html, body, [class*="css"] {
font-family: 'Inter', sans-serif;
}

/* Header styling */
h1 {
font-weight: 700;
color: #1E3A8A;
padding-bottom: 1rem;
border-bottom: 2px solid #E5E7EB;
margin-bottom: 2rem;
}

/* Text area styling */
.stTextArea textarea {
border: 1px solid #D1D5DB;
border-radius: 8px;
font-size: 16px;
padding: 12px;
background-color: #F9FAFB;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
}

.stTextArea textarea:focus {
border-color: #3B82F6;
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
outline: none;
}

/* Button styling */
.stButton button {
border-radius: 8px;
font-weight: 600;
padding: 10px 16px;
background-color: #2563EB;
color: white;
border: none;
width: 100%;
transition: background-color 0.2s ease;
}

.stButton button:hover {
background-color: #1D4ED8;
}

/* Result box styling */
.result-box {
border-radius: 8px;
padding: 20px;
margin-top: 24px;
text-align: center;
background-color: white;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
border: 1px solid #E5E7EB;
}

/* Result highlights */
.highlight-human {
color: #059669;
font-weight: 600;
background: rgba(5, 150, 105, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}

.highlight-ai {
color: #DC2626;
font-weight: 600;
background: rgba(220, 38, 38, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}

/* Footer styling */
.footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #E5E7EB;
color: #6B7280;
font-size: 14px;
}

/* Progress bar styling */
.stProgress > div > div {
background-color: #2563EB;
}

/* General spacing */
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
</style>
""", unsafe_allow_html=True)
119
+
120
# --- Configuration ---
# First ensemble member loads from a local checkpoint; the other two are
# fetched from the Hugging Face Hub at startup.
MODEL1_PATH = "modernbert.bin"
MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
BASE_MODEL = "answerdotai/ModernBERT-base"  # architecture the checkpoints fine-tune
NUM_LABELS = 41          # one class per candidate source (see LABEL_MAPPING below)
HUMAN_LABEL_INDEX = 24   # index of the single 'human' class in the 41-way head
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
+
129
# --- Model Loading Functions ---
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    """Load and cache the Hugging Face tokenizer for ``model_name``.

    st.cache_resource ensures the tokenizer is built once per server
    process rather than on every Streamlit rerun.
    """
    # Imported lazily so the slow transformers import happens inside the
    # cached function, not at module import time.
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(model_name)
134
+
135
@st.cache_resource(show_spinner=False)
def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Build the base classifier architecture and load fine-tuned weights.

    Weights come either from a URL (``is_url=True``) or a local file path.
    Returns the model in eval mode on ``_device``, or ``None`` when the
    checkpoint is missing or fails to load — callers treat ``None`` as a
    load failure. Cached once per process by st.cache_resource.
    """
    # Lazy import keeps the heavy transformers import inside the cached call.
    from transformers import AutoModelForSequenceClassification

    # Instantiate the architecture first; only the weights differ per member.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        base_model, num_labels=num_labels
    )

    # A missing local checkpoint is an expected condition, not an error.
    if not is_url and not os.path.exists(model_path_or_url):
        return None

    try:
        if is_url:
            state = torch.hub.load_state_dict_from_url(
                model_path_or_url, map_location=_device, progress=False
            )
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # acceptable only because this checkpoint ships with the app.
            state = torch.load(
                model_path_or_url, map_location=_device, weights_only=False
            )
        classifier.load_state_dict(state)
        classifier.to(_device).eval()
        return classifier
    except Exception:
        # Best-effort: any failure surfaces to the UI as a None model.
        return None
156
+
157
# --- Text Processing Functions ---
def clean_text(text):
    """Normalise raw input text before tokenisation.

    Unifies line endings, collapses runs of blank lines into a single blank
    line, squeezes horizontal whitespace, re-joins words hyphenated across
    a line break, and turns remaining lone newlines into spaces. Any
    non-string input yields "".
    """
    if not isinstance(text, str):
        return ""
    # (pattern, replacement) pairs applied in order; the order matters
    # (blank-line collapsing must run before lone-newline removal).
    normalisation_steps = (
        (r"\r\n?", "\n"),                   # CRLF or lone CR -> LF
        (r"\n\s*\n+", "\n\n"),              # many blank lines -> one blank line
        (r"[ \t]+", " "),                   # runs of spaces/tabs -> one space
        (r"(\w+)-\s*\n\s*(\w+)", r"\1\2"),  # re-join hyphenated line breaks
        (r"(?<!\n)\n(?!\n)", " "),          # lone newline -> space
    )
    for pattern, replacement in normalisation_steps:
        text = re.sub(pattern, replacement, text)
    return text.strip()
167
+
168
def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Ensemble-classify *text* as human-written or AI-generated.

    Averages the softmax distributions of three classifiers (soft voting)
    and compares the 'human' class probability against the combined mass
    of every AI class.

    Returns:
        None                                   -- cleaned text is empty.
        {"error": True, "message": str}        -- load/config/runtime failure.
        {"is_human", "probability", "model"}   -- successful classification.
    """
    if not all([model_1, model_2, model_3, tokenizer]):
        return {"error": True, "message": "Models failed to load properly."}

    prepared = clean_text(text)
    if not prepared:
        return None

    try:
        encoded = tokenizer(
            prepared,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length
        ).to(device)

        # Per-member softmax distributions, averaged into one distribution.
        with torch.no_grad():
            member_probs = [
                torch.softmax(member(**encoded).logits, dim=1)
                for member in (model_1, model_2, model_3)
            ]
        probabilities = (sum(member_probs) / 3)[0].cpu()

        # Guard against a config mismatch with the classifier head size.
        if not (0 <= human_label_index < len(probabilities)):
            return {"error": True, "message": "Configuration error."}

        human_prob = probabilities[human_label_index].item() * 100

        # Total AI probability = everything except the human class.
        non_human = torch.ones_like(probabilities, dtype=torch.bool)
        non_human[human_label_index] = False
        ai_total_prob = probabilities[non_human].sum().item() * 100

        # Most likely individual AI source (human excluded from the argmax).
        masked = probabilities.clone()
        masked[human_label_index] = -float('inf')
        best_ai_index = torch.argmax(masked).item()
        best_ai_name = label_mapping.get(best_ai_index, f"Unknown AI (Index {best_ai_index})")

        if human_prob >= ai_total_prob:
            return {"is_human": True, "probability": human_prob, "model": "Human"}
        return {"is_human": False, "probability": ai_total_prob, "model": best_ai_name}

    except Exception as e:
        return {"error": True, "message": f"Analysis failed: {str(e)}"}
218
+
219
# --- Label Mapping ---
# Index -> source-name for the 41-way classifier head. Index 24 is the single
# 'human' class (HUMAN_LABEL_INDEX); every other index is a candidate AI model.
LABEL_MAPPING = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
    22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}
232
+
233
# --- Main UI ---
st.title("🕵️ AI Text Detector")

# Initialization with a progress bar.
# BUG FIX: the original called st.empty() *after* rendering the status
# messages, which inserts a brand-new empty element and clears nothing —
# the info banner and progress bar stayed on screen. Render the transient
# widgets inside placeholders and empty those placeholders when done.
with st.spinner(""):
    progress_placeholder = st.empty()
    status_placeholder = st.empty()
    progress_bar = progress_placeholder.progress(0)
    status_placeholder.info("Initializing AI detection models...")

    # Step 1: Load tokenizer
    progress_bar.progress(20)
    time.sleep(0.5)  # Small delay for visual feedback
    TOKENIZER = load_tokenizer(BASE_MODEL)

    # Step 2: Load first model (local checkpoint)
    progress_bar.progress(40)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)

    # Step 3: Load second model (remote)
    progress_bar.progress(60)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Step 4: Load third model (remote)
    progress_bar.progress(80)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Complete initialization
    progress_bar.progress(100)
    time.sleep(0.5)  # Small delay for visual feedback

    # Actually remove the transient initialization UI.
    progress_placeholder.empty()
    status_placeholder.empty()

# Abort the page if the tokenizer or any ensemble member failed to load
# (load_tokenizer/load_model signal failure by returning None).
if not all([TOKENIZER, MODEL_1, MODEL_2, MODEL_3]):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()
273
+
274
# Input area
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input"
)

# Analyze button and output
analyze_button = st.button("Analyze Text", key="analyze_button")
# Placeholder so each analysis replaces the previous result in place.
result_placeholder = st.empty()

if analyze_button:
    if input_text and input_text.strip():
        with st.spinner('Analyzing text...'):
            # Run the three-model ensemble classifier on the submitted text.
            classification_result = classify_text(
                input_text,
                TOKENIZER,
                MODEL_1,
                MODEL_2,
                MODEL_3,
                DEVICE,
                LABEL_MAPPING,
                HUMAN_LABEL_INDEX
            )

        # Display result.
        # classify_text returns None for empty (post-cleaning) input,
        # an error dict on failure, or a classification dict on success.
        if classification_result is None:
            result_placeholder.warning("Please enter some text to analyze.")
        elif classification_result.get("error"):
            error_message = classification_result.get("message", "An unknown error occurred during analysis.")
            result_placeholder.error(f"Analysis Error: {error_message}")
        elif classification_result["is_human"]:
            prob = classification_result['probability']
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
                f"</div>"
            )
            result_placeholder.markdown(result_html, unsafe_allow_html=True)
        else:  # AI generated
            prob = classification_result['probability']
            model_name = classification_result['model']
            result_html = (
                f"<div class='result-box'>"
                f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
                f"<b>Most Likely AI Model: {model_name}</b>"
                f"</div>"
            )
            result_placeholder.markdown(result_html, unsafe_allow_html=True)
    else:
        result_placeholder.warning("Please enter some text to analyze.")

# Footer
st.markdown("<div class='footer'>Developed by Eeman Majumder</div>", unsafe_allow_html=True)