Spaces:
Sleeping
Sleeping
Create new.py
Browse files
new.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torch.hub
|
4 |
+
import re
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
# --- Set Page Config First ---
# st.set_page_config must be the first Streamlit call in the script,
# before any other st.* usage, or Streamlit raises an error.
st.set_page_config(
    page_title="AI Text Detector",   # browser tab title
    layout="centered",               # narrow, centered column layout
    initial_sidebar_state="collapsed"
)
|
14 |
+
|
15 |
+
# --- Improved CSS for a cleaner UI ---
# Injected as raw HTML (unsafe_allow_html=True). Targets Streamlit's
# generated widget classes (.stTextArea, .stButton, .stProgress) plus the
# custom classes used by the result markup below (.result-box,
# .highlight-human, .highlight-ai, .footer).
# NOTE(review): selectors like [class*="css"] depend on Streamlit's internal
# class names and may break across Streamlit versions — confirm on upgrade.
st.markdown("""
<style>
    /* Modern clean font for the entire app */
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

    html, body, [class*="css"] {
        font-family: 'Inter', sans-serif;
    }

    /* Header styling */
    h1 {
        font-weight: 700;
        color: #1E3A8A;
        padding-bottom: 1rem;
        border-bottom: 2px solid #E5E7EB;
        margin-bottom: 2rem;
    }

    /* Text area styling */
    .stTextArea textarea {
        border: 1px solid #D1D5DB;
        border-radius: 8px;
        font-size: 16px;
        padding: 12px;
        background-color: #F9FAFB;
        box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
        transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
    }

    .stTextArea textarea:focus {
        border-color: #3B82F6;
        box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
        outline: none;
    }

    /* Button styling */
    .stButton button {
        border-radius: 8px;
        font-weight: 600;
        padding: 10px 16px;
        background-color: #2563EB;
        color: white;
        border: none;
        width: 100%;
        transition: background-color 0.2s ease;
    }

    .stButton button:hover {
        background-color: #1D4ED8;
    }

    /* Result box styling */
    .result-box {
        border-radius: 8px;
        padding: 20px;
        margin-top: 24px;
        text-align: center;
        background-color: white;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
        border: 1px solid #E5E7EB;
    }

    /* Result highlights */
    .highlight-human {
        color: #059669;
        font-weight: 600;
        background: rgba(5, 150, 105, 0.1);
        padding: 4px 10px;
        border-radius: 8px;
        display: inline-block;
    }

    .highlight-ai {
        color: #DC2626;
        font-weight: 600;
        background: rgba(220, 38, 38, 0.1);
        padding: 4px 10px;
        border-radius: 8px;
        display: inline-block;
    }

    /* Footer styling */
    .footer {
        text-align: center;
        margin-top: 40px;
        padding-top: 20px;
        border-top: 1px solid #E5E7EB;
        color: #6B7280;
        font-size: 14px;
    }

    /* Progress bar styling */
    .stProgress > div > div {
        background-color: #2563EB;
    }

    /* General spacing */
    .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
</style>
""", unsafe_allow_html=True)
|
119 |
+
|
120 |
+
# --- Configuration ---
# Ensemble of three fine-tuned ModernBERT classifiers: one local checkpoint
# plus two downloaded from the Hugging Face Hub.
MODEL1_PATH = "modernbert.bin"  # local fine-tuned weights (loaded via torch.load)
MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"  # remote weights, seed 12
MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"  # remote weights, seed 22
BASE_MODEL = "answerdotai/ModernBERT-base"  # architecture + tokenizer source
NUM_LABELS = 41            # classifier head size; must match LABEL_MAPPING below
HUMAN_LABEL_INDEX = 24     # index of the 'human' class in LABEL_MAPPING
# Prefer GPU when available; all models and inputs are moved to this device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
128 |
+
|
129 |
+
# --- Model Loading Functions ---
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    """Fetch and cache the tokenizer for *model_name*.

    Cached with st.cache_resource so the download happens once per server
    process, not once per Streamlit rerun.
    """
    # Deferred import: keeps the transformers import cost off the first
    # page render until the tokenizer is actually needed.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
|
134 |
+
|
135 |
+
@st.cache_resource(show_spinner=False)
def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Build the base classifier and overlay fine-tuned weights onto it.

    Weights come either from a URL (is_url=True) or a local file path.
    Returns the model in eval mode on *_device*, or None if the local file
    is missing or loading fails for any reason (caller checks for None).
    The leading underscore on _device keeps Streamlit from hashing it as a
    cache key component.
    """
    from transformers import AutoModelForSequenceClassification

    # Instantiate the architecture first; the state dict below replaces its
    # pretrained head/body weights.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        base_model, num_labels=num_labels
    )

    try:
        if is_url:
            weights = torch.hub.load_state_dict_from_url(
                model_path_or_url, map_location=_device, progress=False
            )
        else:
            if not os.path.exists(model_path_or_url):
                return None
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
            # only safe because the checkpoint ships with this app.
            weights = torch.load(
                model_path_or_url, map_location=_device, weights_only=False
            )

        classifier.load_state_dict(weights)
        classifier.to(_device).eval()
        return classifier
    except Exception:
        # Best-effort loading: any failure is reported to the caller as None,
        # which the UI turns into a user-facing error.
        return None
|
156 |
+
|
157 |
+
# --- Text Processing Functions ---
def clean_text(text):
    """Normalize whitespace in *text* before tokenization.

    Normalizes line endings, collapses blank-line runs to a single paragraph
    break, collapses horizontal whitespace, re-joins words hyphenated across
    a line break, and converts single (soft-wrap) newlines to spaces.
    Non-string input yields "".
    """
    if not isinstance(text, str):
        return ""
    # Normalize Windows/old-Mac line endings to "\n".
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    # Collapse runs of blank lines into one paragraph break.
    text = re.sub(r"\n\s*\n+", "\n\n", text)
    # Collapse runs of spaces/tabs.
    text = re.sub(r"[ \t]+", " ", text)
    # Re-join words hyphenated across a line break ("hyphen-\nated" -> "hyphenated").
    text = re.sub(r"(\w+)-\s*\n\s*(\w+)", r"\1\2", text)
    # Single newlines (soft wraps) become spaces; double newlines survive.
    text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
    # BUG FIX: the previous pass can create double spaces (e.g. "word \nword"
    # -> "word  word") because horizontal whitespace was collapsed earlier.
    # Collapse any such residue.
    text = re.sub(r"[ \t]{2,}", " ", text)
    return text.strip()
|
167 |
+
|
168 |
+
def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Classify *text* as human- or AI-written via a three-model ensemble.

    Softmax probabilities from the three models are averaged; the 'human'
    class probability is compared against the summed probability of every
    AI class. Returns:
      - None when the cleaned text is empty,
      - {"error": True, "message": ...} on failure,
      - {"is_human": bool, "probability": float (percent), "model": str} otherwise.
    """
    if not all([model_1, model_2, model_3, tokenizer]):
        return {"error": True, "message": "Models failed to load properly."}

    prepared = clean_text(text)
    if not prepared:
        return None

    try:
        encoded = tokenizer(
            prepared,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length,
        ).to(device)

        # Run each ensemble member and convert its logits to per-class
        # probabilities; gradients are disabled for inference.
        with torch.no_grad():
            member_probs = [
                torch.softmax(member(**encoded).logits, dim=1)
                for member in (model_1, model_2, model_3)
            ]

        averaged = sum(member_probs) / 3
        probs = averaged[0].cpu()  # single input -> first (only) row

        # Guard against a misconfigured human index vs. the head size.
        if not (0 <= human_label_index < len(probs)):
            return {"error": True, "message": "Configuration error."}

        human_prob = probs[human_label_index].item() * 100

        # Total AI probability = everything except the human class.
        ai_mask = torch.ones_like(probs, dtype=torch.bool)
        ai_mask[human_label_index] = False
        ai_total_prob = probs[ai_mask].sum().item() * 100

        # Most likely AI model: argmax with the human class excluded.
        without_human = probs.clone()
        without_human[human_label_index] = -float('inf')
        ai_argmax_index = torch.argmax(without_human).item()
        ai_argmax_model = label_mapping.get(ai_argmax_index, f"Unknown AI (Index {ai_argmax_index})")

        if human_prob >= ai_total_prob:
            return {"is_human": True, "probability": human_prob, "model": "Human"}
        return {"is_human": False, "probability": ai_total_prob, "model": ai_argmax_model}

    except Exception as e:
        return {"error": True, "message": f"Analysis failed: {str(e)}"}
|
218 |
+
|
219 |
+
# --- Label Mapping ---
# Maps classifier output indices (0..NUM_LABELS-1) to generator names.
# Index 24 ('human') must match HUMAN_LABEL_INDEX above; every other index
# names a candidate AI text generator.
LABEL_MAPPING = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
    14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
    18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
    22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
    27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
    31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
    35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
    39: 'text-davinci-002', 40: 'text-davinci-003'
}
|
232 |
+
|
233 |
+
# --- Main UI ---
st.title("🕵️ AI Text Detector")

# Initialization with a progress bar. Loading is cached (st.cache_resource),
# so reruns after the first are near-instant.
with st.spinner(""):
    # BUG FIX: render the progress bar and the info banner into placeholders.
    # The original called a bare st.empty() at the end, which creates a NEW
    # empty slot and does NOT remove the already-rendered progress bar and
    # info message — they stayed on screen permanently.
    progress_slot = st.empty()
    status_slot = st.empty()
    progress_bar = progress_slot.progress(0)
    status_slot.info("Initializing AI detection models...")

    # Step 1: Load tokenizer
    progress_bar.progress(20)
    time.sleep(0.5)  # Small delay for visual feedback
    TOKENIZER = load_tokenizer(BASE_MODEL)

    # Step 2: Load first model (local checkpoint)
    progress_bar.progress(40)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)

    # Step 3: Load second model (remote weights)
    progress_bar.progress(60)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Step 4: Load third model (remote weights)
    progress_bar.progress(80)
    time.sleep(0.5)  # Small delay for visual feedback
    MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

    # Complete initialization
    progress_bar.progress(100)
    time.sleep(0.5)  # Small delay for visual feedback

    # Clear the initialization widgets now that loading is done.
    progress_slot.empty()
    status_slot.empty()

# Check if models loaded successfully; load_tokenizer/load_model return
# None (or raise inside cache) on failure.
if not all([TOKENIZER, MODEL_1, MODEL_2, MODEL_3]):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()
|
273 |
+
|
274 |
+
# Input area
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input",
)

# Analyze button and a single placeholder that every outcome writes into,
# so a new run replaces the previous result in place.
analyze_button = st.button("Analyze Text", key="analyze_button")
result_placeholder = st.empty()

if analyze_button:
    # Guard clause: nothing (or only whitespace) entered.
    if not (input_text and input_text.strip()):
        result_placeholder.warning("Please enter some text to analyze.")
    else:
        with st.spinner('Analyzing text...'):
            verdict = classify_text(
                input_text,
                TOKENIZER,
                MODEL_1,
                MODEL_2,
                MODEL_3,
                DEVICE,
                LABEL_MAPPING,
                HUMAN_LABEL_INDEX,
            )

        # Display result: None means the text was empty after cleaning.
        if verdict is None:
            result_placeholder.warning("Please enter some text to analyze.")
        elif verdict.get("error"):
            error_message = verdict.get("message", "An unknown error occurred during analysis.")
            result_placeholder.error(f"Analysis Error: {error_message}")
        else:
            prob = verdict['probability']
            if verdict["is_human"]:
                result_html = (
                    f"<div class='result-box'>"
                    f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
                    f"</div>"
                )
            else:  # AI generated
                model_name = verdict['model']
                result_html = (
                    f"<div class='result-box'>"
                    f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
                    f"<b>Most Likely AI Model: {model_name}</b>"
                    f"</div>"
                )
            result_placeholder.markdown(result_html, unsafe_allow_html=True)
|
326 |
+
|
327 |
+
# Footer — styled by the .footer CSS class injected above.
st.markdown("<div class='footer'>Developed by Eeman Majumder</div>", unsafe_allow_html=True)
|