# aigtd / app.py
# Last commit e289c6c ("Update app.py", verified) by Eemansleepdeprived.
import streamlit as st
import torch
import torch.hub
import re
import os
import time
# --- Page configuration ---
# Kept ahead of every other st.* call: older Streamlit releases reject
# set_page_config unless it is the first Streamlit command of the run.
_PAGE_CONFIG = dict(
    page_title="AI Text Detector",
    layout="centered",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
# --- Custom CSS ---
# Injected as raw HTML (hence unsafe_allow_html=True). It loads the Inter web
# font and restyles the heading, text area, buttons, the result card
# (.result-box / .highlight-human / .highlight-ai), the footer, the progress
# bar and the page padding. The result-card and footer class names must match
# the HTML emitted further down in this script.
_APP_CSS = """
<style>
/* Modern clean font for the entire app */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
html, body, [class*="css"] {
font-family: 'Inter', sans-serif;
}
/* Header styling */
h1 {
font-weight: 700;
color: #1E3A8A;
padding-bottom: 1rem;
border-bottom: 2px solid #E5E7EB;
margin-bottom: 2rem;
}
/* Text area styling */
.stTextArea textarea {
border: 1px solid #D1D5DB;
border-radius: 8px;
font-size: 16px;
padding: 12px;
background-color: #F9FAFB;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out;
}
.stTextArea textarea:focus {
border-color: #3B82F6;
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.3);
outline: none;
}
/* Button styling */
.stButton button {
border-radius: 8px;
font-weight: 600;
padding: 10px 16px;
background-color: #2563EB;
color: white;
border: none;
width: 100%;
transition: background-color 0.2s ease;
}
.stButton button:hover {
background-color: #1D4ED8;
}
/* Result box styling */
.result-box {
border-radius: 8px;
padding: 20px;
margin-top: 24px;
text-align: center;
background-color: white;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1), 0 1px 2px rgba(0, 0, 0, 0.06);
border: 1px solid #E5E7EB;
}
/* Result highlights */
.highlight-human {
color: #059669;
font-weight: 600;
background: rgba(5, 150, 105, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}
.highlight-ai {
color: #DC2626;
font-weight: 600;
background: rgba(220, 38, 38, 0.1);
padding: 4px 10px;
border-radius: 8px;
display: inline-block;
}
/* Footer styling */
.footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #E5E7EB;
color: #6B7280;
font-size: 14px;
}
/* Progress bar styling */
.stProgress > div > div {
background-color: #2563EB;
}
/* General spacing */
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# --- Configuration ---
# Checkpoint for the first ensemble member; a relative path, so it is resolved
# against the process's current working directory.
MODEL1_PATH = "modernbert.bin"
# Checkpoints for the second and third ensemble members, downloaded at startup.
MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
# Architecture and tokenizer shared by all three ensemble members.
BASE_MODEL = "answerdotai/ModernBERT-base"
# Width of the classification head; must agree with LABEL_MAPPING (41 entries).
NUM_LABELS = 41
# Class index that means "human-written" (LABEL_MAPPING[24] == 'human').
HUMAN_LABEL_INDEX = 24

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# --- Model Loading Functions ---
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_name):
    """Return the Hugging Face tokenizer for *model_name*.

    Wrapped in st.cache_resource, so the tokenizer is built once and the same
    object is reused across reruns and sessions. ``transformers`` is imported
    inside the function so the (heavy) import happens only on a cache miss.
    """
    from transformers import AutoTokenizer as _AutoTokenizer

    tokenizer = _AutoTokenizer.from_pretrained(model_name)
    return tokenizer
@st.cache_resource(show_spinner=False)
def _load_model_or_raise(model_path_or_url, base_model, num_labels, is_url, _device):
    """Build a ``base_model`` classifier with ``num_labels`` outputs and load fine-tuned weights.

    Every failure raises instead of returning ``None``. This is deliberate:
    st.cache_resource only stores values the function actually returns, so a
    failed attempt (missing file, network error while downloading) is retried
    on the next rerun rather than a ``None`` being cached for the lifetime of
    the server process.

    The leading underscore on ``_device`` keeps it out of Streamlit's cache key.
    """
    from transformers import AutoModelForSequenceClassification

    # Check for a local checkpoint first so a missing file fails fast, before
    # the base architecture is downloaded and instantiated.
    if not is_url and not os.path.exists(model_path_or_url):
        raise FileNotFoundError(f"Model checkpoint not found: {model_path_or_url}")

    if is_url:
        state_dict = torch.hub.load_state_dict_from_url(
            model_path_or_url, map_location=_device, progress=False
        )
    else:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects. Only acceptable for a trusted checkpoint (presumably the one
        # bundled with this app — confirm); try weights_only=True if the file
        # is a plain state_dict.
        state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False)

    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)
    model.load_state_dict(state_dict)
    model.to(_device).eval()
    return model


def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Load one ensemble member, returning ``None`` if anything goes wrong.

    Args:
        model_path_or_url: Local checkpoint path, or a URL when ``is_url`` is True.
        base_model: Hugging Face model id providing the architecture.
        num_labels: Width of the classification head; must match the checkpoint.
        is_url: Fetch the state dict with torch.hub instead of reading a file.
        _device: torch.device the weights are mapped to and the model moved to.

    Returns:
        The model in eval mode on ``_device``, or ``None`` if the checkpoint is
        missing or building/loading the model raised. Failures are logged, and
        they are not cached, so the next rerun tries again.
    """
    try:
        return _load_model_or_raise(model_path_or_url, base_model, num_labels, is_url, _device)
    except Exception:
        # UI boundary: callers treat None as "failed to load" and show an error,
        # so keep the traceback in the server log instead of discarding it.
        import logging

        logging.getLogger(__name__).exception("Failed to load model from %s", model_path_or_url)
        return None
# --- Text Processing Functions ---
def clean_text(text):
    """Normalise line endings and whitespace in *text* before tokenisation.

    Non-string input yields "". For strings, the steps run in this order:
    CRLF/CR -> LF; runs of blank lines -> one blank line; runs of spaces/tabs
    -> one space; words hyphenated across a line break are re-joined (the
    pattern's ``\\s*`` can also span a blank line); remaining single newlines
    become spaces while paragraph breaks ("\\n\\n") are kept. The result is
    stripped of leading/trailing whitespace.
    """
    if isinstance(text, str):
        normalized = text.replace("\r\n", "\n").replace("\r", "\n")
        for pattern, replacement in (
            (r"\n\s*\n+", "\n\n"),
            (r"[ \t]+", " "),
            (r"(\w+)-\s*\n\s*(\w+)", r"\1\2"),
            (r"(?<!\n)\n(?!\n)", " "),
        ):
            normalized = re.sub(pattern, replacement, normalized)
        return normalized.strip()
    return ""
def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Classify *text* as human-written or AI-generated with an equal-weight ensemble.

    The text is normalised with clean_text, tokenised once, and fed through
    each model. The per-model softmax distributions are averaged. The "human"
    score is the averaged probability of ``human_label_index``; the "AI" score
    is the summed probability of every other class.

    Args:
        text: Raw user input.
        tokenizer: Callable tokenizer shared by all models (must support
            ``return_tensors="pt"``).
        model_1, model_2, model_3: Classifiers returning an object with ``.logits``.
        device: Device the tokenised inputs are moved to.
        label_mapping: Dict from class index to a display name for AI classes.
        human_label_index: Index of the human class in the model outputs.

    Returns:
        ``None`` if nothing is left after cleaning.
        ``{"error": True, "message": str}`` if a model is missing, the index is
        out of range, or inference raises.
        Otherwise ``{"is_human": bool, "probability": float, "model": str}``,
        where ``probability`` is a percentage (0-100) for the winning side and
        ``model`` is ``"Human"`` or the most likely AI class name.
    """
    models = (model_1, model_2, model_3)
    # Explicit None checks: tokenizers define __len__, so plain truthiness would
    # test vocabulary size rather than whether loading succeeded.
    if tokenizer is None or any(model is None for model in models):
        return {"error": True, "message": "Models failed to load properly."}

    cleaned_text = clean_text(text)
    if not cleaned_text:
        return None

    try:
        inputs = tokenizer(
            cleaned_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=tokenizer.model_max_length,
        ).to(device)

        with torch.no_grad():
            # Equal-weight ensemble: average the per-model class distributions.
            per_model = [torch.softmax(model(**inputs).logits, dim=1) for model in models]
            # Single input, so keep row 0 of the batch.
            probabilities = (sum(per_model) / len(per_model))[0].cpu()

        if not (0 <= human_label_index < len(probabilities)):
            return {"error": True, "message": "Configuration error."}

        human_prob = probabilities[human_label_index].item() * 100

        ai_mask = torch.ones_like(probabilities, dtype=torch.bool)
        ai_mask[human_label_index] = False
        ai_total_prob = probabilities[ai_mask].sum().item() * 100

        # Most likely generator: argmax over the AI classes only.
        ai_probs_only = probabilities.clone()
        ai_probs_only[human_label_index] = -float('inf')
        ai_argmax_index = torch.argmax(ai_probs_only).item()
        ai_argmax_model = label_mapping.get(ai_argmax_index, f"Unknown AI (Index {ai_argmax_index})")
    except Exception as e:
        # UI boundary: surface the failure to the page instead of crashing it.
        return {"error": True, "message": f"Analysis failed: {e}"}

    if human_prob >= ai_total_prob:
        return {"is_human": True, "probability": human_prob, "model": "Human"}
    return {"is_human": False, "probability": ai_total_prob, "model": ai_argmax_model}
# --- Label Mapping ---
# Class index -> label for the 41-way classifier head, listed in index order.
# Index 24 is the only human class; every other index is an AI-generator label.
_LABEL_NAMES = [
    '13B', '30B', '65B', '7B', 'GLM130B', 'bloom_7b',               # 0-5
    'bloomz', 'cohere', 'davinci', 'dolly', 'dolly-v2-12b',          # 6-10
    'flan_t5_base', 'flan_t5_large', 'flan_t5_small',               # 11-13
    'flan_t5_xl', 'flan_t5_xxl', 'gemma-7b-it', 'gemma2-9b-it',     # 14-17
    'gpt-3.5-turbo', 'gpt-35', 'gpt4', 'gpt4o',                      # 18-21
    'gpt_j', 'gpt_neox', 'human', 'llama3-70b', 'llama3-8b',        # 22-26
    'mixtral-8x7b', 'opt_1.3b', 'opt_125m', 'opt_13b',              # 27-30
    'opt_2.7b', 'opt_30b', 'opt_350m', 'opt_6.7b',                  # 31-34
    'opt_iml_30b', 'opt_iml_max_1.3b', 't0_11b', 't0_3b',           # 35-38
    'text-davinci-002', 'text-davinci-003',                         # 39-40
]
LABEL_MAPPING = dict(enumerate(_LABEL_NAMES))
# --- Main UI ---
st.title("AI Text Detector")

# Model initialization with a progress bar.
# Streamlit re-executes this script top to bottom on every widget interaction.
# The loaders are backed by st.cache_resource, so only the first run pays the
# real loading cost. No artificial delays are added here: a sleep would be paid
# again on every rerun, i.e. on every click of "Analyze Text".
with st.spinner(""):
    progress_bar = st.progress(0)
    status_message = st.empty()
    status_message.info("Initializing AI detection models...")

    # Step 1: tokenizer shared by all three models.
    progress_bar.progress(20)
    TOKENIZER = load_tokenizer(BASE_MODEL)

    # Steps 2-4: the three ensemble members (one local file, two URLs).
    progress_bar.progress(40)
    MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)
    progress_bar.progress(60)
    MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)
    progress_bar.progress(80)
    MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)
    progress_bar.progress(100)

# Remove the transient initialization widgets. A bare st.empty() call would only
# append a new blank placeholder; the specific elements must be emptied instead.
progress_bar.empty()
status_message.empty()

# load_model signals failure with None; stop rendering if anything is missing.
if any(component is None for component in (TOKENIZER, MODEL_1, MODEL_2, MODEL_3)):
    st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.")
    st.stop()
# --- Input area ---
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here for AI detection analysis...",
    height=200,
    key="text_input",
)

# st.button is True only on the rerun triggered by the click.
analyze_button = st.button("Analyze Text", key="analyze_button")
# One slot, reused for whichever warning, error or result ends up being shown.
result_placeholder = st.empty()


def _show_classification(placeholder, result):
    """Render a classify_text() result into *placeholder*.

    ``None`` (nothing left after cleaning) becomes an input warning, a dict
    with a truthy ``"error"`` becomes an error box, and anything else is drawn
    as an HTML result card styled by the custom CSS classes.
    """
    if result is None:
        placeholder.warning("Please enter some text to analyze.")
        return
    if result.get("error"):
        reason = result.get("message", "An unknown error occurred during analysis.")
        placeholder.error(f"Analysis Error: {reason}")
        return

    prob = result['probability']
    if result["is_human"]:
        verdict = f"<b>The text is</b> <span class='highlight-human'><b>{prob:.2f}%</b> likely <b>Human written</b>.</span>"
    else:
        verdict = (
            f"<b>The text is</b> <span class='highlight-ai'><b>{prob:.2f}%</b> likely <b>AI generated</b>.</span><br><br>"
            f"<b>Most Likely AI Model: {result['model']}</b>"
        )
    placeholder.markdown(f"<div class='result-box'>{verdict}</div>", unsafe_allow_html=True)


if analyze_button:
    if not (input_text and input_text.strip()):
        result_placeholder.warning("Please enter some text to analyze.")
    else:
        with st.spinner('Analyzing text...'):
            classification_result = classify_text(
                input_text,
                TOKENIZER,
                MODEL_1,
                MODEL_2,
                MODEL_3,
                DEVICE,
                LABEL_MAPPING,
                HUMAN_LABEL_INDEX,
            )
        _show_classification(result_placeholder, classification_result)

# --- Footer ---
_FOOTER_HTML = "<div class='footer'>Developed by Eeman Majumder</div>"
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)