# NOTE(review): this file appears to be a damaged rendering of a Streamlit app.
#   * Newlines and indentation were collapsed, so each physical line below holds
#     many statements. As written, everything after the first "#" on a line is
#     a comment, and the first code line ("import streamlit as st import torch
#     ...") is a SyntaxError. Restore the original line structure from version
#     control rather than by hand.
#   * A span of text is missing on the first code line, between `re.sub(r"(?`
#     and `= ai_total_prob:`. The rest of clean_text() and the head of
#     classify_text() are gone. That includes the tokenization, the model
#     inference and whatever computes human_prob / ai_total_prob /
#     ai_argmax_model. Do not reconstruct it by guesswork.
#   * Some HTML appears to have been stripped. unsafe_allow_html=True is passed,
#     yet no tags remain in the CSS string given to st.markdown, the footer
#     st.markdown(""), or the result_html f-string fragments. Confirm against
#     the original source.
#
# Purpose (as far as the visible code shows): a Streamlit page that loads one
# tokenizer and three sequence-classification models built on BASE_MODEL. One
# model comes from the local file MODEL1_PATH and two are downloaded from
# MODEL2_URL / MODEL3_URL. The page calls classify_text() on the user's input
# and reports either a "Human written" verdict or an "AI generated" verdict
# with the most likely source model name.
#
# Review notes on the visible code (to address once the file is restored):
#   * NUM_LABELS = 41 matches LABEL_MAPPING (keys 0..40), and
#     HUMAN_LABEL_INDEX = 24 matches LABEL_MAPPING[24] == 'human'. Keep the
#     three in sync if the label set changes.
#   * Possible head-size mismatch: the MODEL2_URL / MODEL3_URL names contain
#     "3class", but every model is built with a NUM_LABELS (41) head.
#     model.load_state_dict() is strict by default, so a mismatch would raise,
#     and load_model() would return None. Confirm the checkpoints really have
#     41 outputs.
#   * load_model(): errors are handled inconsistently.
#       - from_pretrained(base_model, ...) runs before the `try`, so a failure
#         there propagates.
#       - Any failure while loading the weights is swallowed by
#         `except Exception: return None`, with no logging.
#       - The UI then shows only the generic "Failed to initialize" error.
#         Log the exception (or surface it) instead.
#   * load_model(): the leading underscore in `_device` tells
#     st.cache_resource not to hash that argument, so the cache key ignores
#     the device.
#   * torch.load(..., weights_only=False) unpickles arbitrary objects. Only use
#     it with a trusted MODEL1_PATH, or switch to weights_only=True if the
#     file is a plain state dict.
#   * The five time.sleep(0.5) calls run at script top level. Streamlit
#     re-executes the script on every widget interaction (including clicking
#     "Analyze Text"), so each rerun waits about 2.5 s even after the cached
#     models are loaded.
#   * The bare `st.empty()` after initialization inserts a new empty element.
#     It does not remove progress_bar or the st.info() message. To clear them,
#     keep handles to those elements and call .empty() on them.
#   * The `classification_result is None` branch shows the "Please enter some
#     text" warning. Presumably classify_text() returns None when the cleaned
#     text is empty; verify once its body is recovered.
#   * result_html formats `probability` as "{prob:.2f}%", which assumes
#     classify_text() returns a value on a 0-100 scale. The scaling lives in
#     the missing span, so confirm it.
import streamlit as st import torch import torch.hub import re import os import time # --- Set Page Config First --- st.set_page_config( page_title="AI Text Detector", layout="centered", initial_sidebar_state="collapsed" ) # --- Improved CSS for a cleaner UI --- st.markdown(""" """, unsafe_allow_html=True) # --- Configuration --- MODEL1_PATH = "modernbert.bin" MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12" MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22" BASE_MODEL = "answerdotai/ModernBERT-base" NUM_LABELS = 41 HUMAN_LABEL_INDEX = 24 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # --- Model Loading Functions --- @st.cache_resource(show_spinner=False) def load_tokenizer(model_name): from transformers import AutoTokenizer return AutoTokenizer.from_pretrained(model_name) @st.cache_resource(show_spinner=False) def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE): from transformers import AutoModelForSequenceClassification # Load base model architecture model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels) try: # Load weights if is_url: state_dict = torch.hub.load_state_dict_from_url(model_path_or_url, map_location=_device, progress=False) else: if not os.path.exists(model_path_or_url): return None state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False) model.load_state_dict(state_dict) model.to(_device).eval() return model except Exception: return None # --- Text Processing Functions --- def clean_text(text): if not isinstance(text, str): return "" text = text.replace("\r\n", "\n").replace("\r", "\n") text = re.sub(r"\n\s*\n+", "\n\n", text) text = re.sub(r"[ \t]+", " ", text) text = re.sub(r"(\w+)-\s*\n\s*(\w+)", r"\1\2", text) text = re.sub(r"(?= ai_total_prob: return {"is_human": True, "probability": human_prob, "model": "Human"} 
else: return {"is_human": False, "probability": ai_total_prob, "model": ai_argmax_model} except Exception as e: return {"error": True, "message": f"Analysis failed: {str(e)}"} # --- Label Mapping --- LABEL_MAPPING = { 0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b', 6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b', 11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small', 14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it', 18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o', 22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b', 27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b', 31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b', 35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b', 39: 'text-davinci-002', 40: 'text-davinci-003' } # --- Main UI --- st.title("AI Text Detector") # Initialization with a progress bar with st.spinner(""): # Create a progress bar progress_bar = st.progress(0) st.info("Initializing AI detection models...") # Step 1: Load tokenizer progress_bar.progress(20) time.sleep(0.5) # Small delay for visual feedback TOKENIZER = load_tokenizer(BASE_MODEL) # Step 2: Load first model progress_bar.progress(40) time.sleep(0.5) # Small delay for visual feedback MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE) # Step 3: Load second model progress_bar.progress(60) time.sleep(0.5) # Small delay for visual feedback MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE) # Step 4: Load third model progress_bar.progress(80) time.sleep(0.5) # Small delay for visual feedback MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE) # Complete initialization progress_bar.progress(100) time.sleep(0.5) # Small delay for visual feedback # Clear the initialization messages st.empty() # Check if models loaded successfully if not 
all([TOKENIZER, MODEL_1, MODEL_2, MODEL_3]): st.error("Failed to initialize one or more AI detection models. Please try refreshing the page.") st.stop() # Input area input_text = st.text_area( label="Enter text to analyze:", placeholder="Type or paste your content here for AI detection analysis...", height=200, key="text_input" ) # Analyze button and output analyze_button = st.button("Analyze Text", key="analyze_button") result_placeholder = st.empty() if analyze_button: if input_text and input_text.strip(): with st.spinner('Analyzing text...'): classification_result = classify_text( input_text, TOKENIZER, MODEL_1, MODEL_2, MODEL_3, DEVICE, LABEL_MAPPING, HUMAN_LABEL_INDEX ) # Display result if classification_result is None: result_placeholder.warning("Please enter some text to analyze.") elif classification_result.get("error"): error_message = classification_result.get("message", "An unknown error occurred during analysis.") result_placeholder.error(f"Analysis Error: {error_message}") elif classification_result["is_human"]: prob = classification_result['probability'] result_html = ( f"
" f"The text is {prob:.2f}% likely Human written." f"
" ) result_placeholder.markdown(result_html, unsafe_allow_html=True) else: # AI generated prob = classification_result['probability'] model_name = classification_result['model'] result_html = ( f"
" f"The text is {prob:.2f}% likely AI generated.

" f"Most Likely AI Model: {model_name}" f"
" ) result_placeholder.markdown(result_html, unsafe_allow_html=True) else: result_placeholder.warning("Please enter some text to analyze.") # Footer st.markdown("", unsafe_allow_html=True)