import streamlit as st
import torch
import torch.hub
import re
import os

# --- Set Page Config First ---
st.set_page_config(page_title="AI Text Detector", layout="centered")

# --- Configuration ---
MODEL1_PATH = "modernbert.bin"  # Make sure this file is in the same directory or provide the full path.
MODEL2_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
MODEL3_URL = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41

# --- Device Setup ---
@st.cache_resource
def get_device():
    """Gets the appropriate torch device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Inject Custom CSS for highlighting ---
st.markdown("""
""", unsafe_allow_html=True)

DEVICE = get_device()  # Now we can safely continue with the rest of the code.

# --- Model and Tokenizer Loading (Cached) ---
@st.cache_resource
def load_tokenizer(model_name):
    """Loads the tokenizer."""
    st.info(f"Loading tokenizer: {model_name}...")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    st.info("Tokenizer loaded.")
    return tokenizer

@st.cache_resource
def load_model(model_path_or_url, base_model, num_labels, is_url=False, _device=DEVICE):
    """Loads a sequence classification model from a local path or a URL."""
    from transformers import AutoModelForSequenceClassification

    model_name = os.path.basename(model_path_or_url) if not is_url else model_path_or_url.split('/')[-1]

    st.info(f"Loading model structure: {base_model}...")
    # Load the base model architecture with the desired number of labels.
    # The classification head is randomly initialized at this point.
    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)

    st.info(f"Loading model weights: {model_name}...")
    try:
        if is_url:
            # Load the state dict from a URL (torch.hub handles download and caching).
            state_dict = torch.hub.load_state_dict_from_url(model_path_or_url, map_location=_device, progress=True)
        else:
            # Load the state dict from a local file.
            if not os.path.exists(model_path_or_url):
                st.error(f"Model file not found at {model_path_or_url}. Please ensure it's in the correct location.")
                st.stop()  # Stop execution if the local model is missing.

            # --- FIX APPLIED HERE ---
            # Set weights_only=False because the .bin file likely contains more than just
            # weights, and PyTorch 2.6+ defaults to weights_only=True for security.
            # WARNING: Only use weights_only=False if you TRUST the source of the .bin file,
            # as it can execute arbitrary code.
            st.warning(f"Loading '{model_name}' with weights_only=False. Ensure this file is from a trusted source.")
            state_dict = torch.load(model_path_or_url, map_location=_device, weights_only=False)
            # --- END FIX ---

        # Load the state dictionary into the model structure. This overwrites the randomly
        # initialized classification head when the state_dict contains the trained classifier
        # weights. The warning "Some weights were not initialized..." may still appear but is
        # usually ignorable if loading succeeds without key errors.
        model.load_state_dict(state_dict)
        model.to(_device).eval()  # Set the model to evaluation mode.
        st.info(f"Model {model_name} loaded and moved to {_device}.")
        return model
    except Exception as e:
        st.error(f"Error loading model {model_name}: {e}")
        # Display the full traceback for debugging if needed:
        # import traceback
        # st.error(traceback.format_exc())
        st.stop()  # Stop execution on a model loading error.
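# The helper below is an optional, illustrative sketch (not used by the app flow above):
# a safer loading pattern that tries the restricted weights_only=True unpickler first and
# falls back to a full unpickle only for trusted files. The name `safe_load_state_dict`
# is an assumption, not part of the original code; torch.load has accepted the
# weights_only flag since PyTorch 1.13.
def safe_load_state_dict(path, map_location):
    """Try the restricted weights-only load first; fall back only for trusted files."""
    try:
        # Restricted unpickler: only plain tensors/containers, cannot execute code.
        return torch.load(path, map_location=map_location, weights_only=True)
    except Exception:
        # Full unpickling can execute arbitrary code -- acceptable only for trusted files.
        return torch.load(path, map_location=map_location, weights_only=False)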
# --- Label Mapping ---
LABEL_MAPPING = {
    0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
    6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
    11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small', 14: 'flan_t5_xl',
    15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it', 18: 'gpt-3.5-turbo',
    19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o', 22: 'gpt_j', 23: 'gpt_neox', 24: 'human',
    25: 'llama3-70b', 26: 'llama3-8b', 27: 'mixtral-8x7b', 28: 'opt_1.3b',
    29: 'opt_125m', 30: 'opt_13b', 31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m',
    34: 'opt_6.7b', 35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b',
    38: 't0_3b', 39: 'text-davinci-002', 40: 'text-davinci-003'
}
HUMAN_LABEL_INDEX = 24  # Assuming 'human' is always index 24.

# --- Text Processing Functions ---
def clean_text(text):
    """Cleans the input text using regex."""
    if not isinstance(text, str):  # Basic type check.
        return ""
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\n\s*\n+", "\n\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    # Rejoin hyphenated words broken by a newline (handles a potential space after the hyphen).
    text = re.sub(r"(\w+)-\s*\n\s*(\w+)", r"\1\2", text)
    # Collapse single newlines inside paragraphs into spaces, preserving blank-line breaks.
    text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
    return text.strip()

def classify_text(text, tokenizer, model_1, model_2, model_3, device, label_mapping, human_label_index):
    """Classifies text with the three-model ensemble.

    Returns None for empty input, an error dict on failure, or a result dict
    with 'is_human', 'probability', and 'model' keys.
    """
    cleaned_text = clean_text(text)
    if not cleaned_text:
        return None
    try:
        inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            # Soft voting: average the softmax probabilities of the three models.
            probs = (
                torch.softmax(model_1(**inputs).logits, dim=1)
                + torch.softmax(model_2(**inputs).logits, dim=1)
                + torch.softmax(model_3(**inputs).logits, dim=1)
            )[0] / 3
        human_prob = probs[human_label_index].item()
        # Sum the probability mass of all non-human classes.
        ai_probs = probs.clone()
        ai_probs[human_label_index] = 0
        ai_total_prob = ai_probs.sum().item()
        ai_argmax_model = label_mapping[ai_probs.argmax().item()]
        if human_prob >= ai_total_prob:
            return {"is_human": True, "probability": human_prob, "model": "Human"}
        else:
            # Return the total AI probability, but name the single most likely AI model.
            return {"is_human": False, "probability": ai_total_prob, "model": ai_argmax_model}
    except Exception as e:
        st.error(f"Error during model inference: {e}")
        # import traceback
        # st.error(traceback.format_exc())  # Uncomment for a detailed traceback during debugging.
        return {"error": True, "message": f"Inference failed: {e}"}
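# Illustrative usage (assumed values, for orientation only) -- the result dict drives
# the UI branches below:
#   classify_text("Typed paragraph...", TOKENIZER, MODEL_1, MODEL_2, MODEL_3,
#                 DEVICE, LABEL_MAPPING, HUMAN_LABEL_INDEX)
#   -> {"is_human": True,  "probability": 0.91, "model": "Human"}   # human verdict
#   -> {"is_human": False, "probability": 0.87, "model": "gpt4o"}   # AI verdict
#   -> {"error": True, "message": "Inference failed: ..."}          # on failure
#   -> None                                                         # empty after cleaning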
# --- Main UI Section ---
st.title("AI Text Detector")

# Load the models and tokenizer.
TOKENIZER = load_tokenizer(BASE_MODEL)
MODEL_1 = load_model(MODEL1_PATH, BASE_MODEL, NUM_LABELS, is_url=False, _device=DEVICE)
MODEL_2 = load_model(MODEL2_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)
MODEL_3 = load_model(MODEL3_URL, BASE_MODEL, NUM_LABELS, is_url=True, _device=DEVICE)

# --- Input Area ---
input_text = st.text_area(
    label="Enter text to analyze:",
    placeholder="Type or paste your content here...",
    height=200,
    key="text_input"
)

# --- Analyze Button and Output ---
analyze_button = st.button("Analyze Text", key="analyze_button")
result_placeholder = st.empty()  # Placeholder for the result output.

if analyze_button:
    # Proceed only if the input is not None and contains non-whitespace characters.
    if input_text and input_text.strip():
        with st.spinner('Analyzing text... This might take a moment.'):
            # --- Perform Classification ---
            classification_result = classify_text(
                input_text, TOKENIZER, MODEL_1, MODEL_2, MODEL_3,
                DEVICE, LABEL_MAPPING, HUMAN_LABEL_INDEX
            )

            # --- Display Result ---
            if classification_result is None:
                # Handles input that is empty/whitespace after cleaning.
                result_placeholder.warning("Please enter some text to analyze.")
            elif classification_result.get("error"):
                error_message = classification_result.get("message", "An unknown error occurred during analysis.")
                result_placeholder.error(f"Analysis Error: {error_message}")
            elif classification_result["is_human"]:
                prob = classification_result['probability']
                result_html = (
                    f"