import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForSequenceClassification import spacy from typing import List, Dict import logging import os from colorama import init, Fore, Back, Style # Initialize colorama for colored terminal output init() # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Constants - matching original implementations MAX_LENGTH = 512 MODEL_NAME = "microsoft/deberta-v3-small" WINDOW_SIZE = 17 WINDOW_OVERLAP = 2 CONFIDENCE_THRESHOLD = 0.65 BATCH_SIZE = 16 # Matching original batch size class TextProcessor: def __init__(self): try: self.nlp = spacy.load("en_core_web_sm") except OSError: logger.info("Downloading spacy model...") os.system("python -m spacy download en_core_web_sm") self.nlp = spacy.load("en_core_web_sm") if 'sentencizer' not in self.nlp.pipe_names: self.nlp.add_pipe('sentencizer') disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer'] self.nlp.disable_pipes(*disabled_pipes) def split_into_sentences(self, text: str) -> List[str]: doc = self.nlp(text) return [str(sent).strip() for sent in doc.sents] def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]: if len(sentences) < window_size: return [" ".join(sentences)] windows = [] stride = window_size - overlap for i in range(0, len(sentences) - window_size + 1, stride): window = sentences[i:i + window_size] windows.append(" ".join(window)) return windows def create_centered_windows(self, sentences: List[str], window_size: int) -> tuple[List[str], List[List[int]]]: """Create windows centered around each sentence for detailed analysis.""" windows = [] window_sentence_indices = [] for i in range(len(sentences)): # Calculate window boundaries centered on current sentence half_window = window_size // 2 start_idx = max(0, i - half_window) end_idx = min(len(sentences), i + half_window + 1) # Adjust window if we're near the edges if start_idx == 0: end_idx = min(len(sentences), window_size) elif end_idx == len(sentences): start_idx = max(0, len(sentences) - window_size) window = sentences[start_idx:end_idx] windows.append(" ".join(window)) window_sentence_indices.append(list(range(start_idx, end_idx))) return windows, window_sentence_indices class AITextDetector: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.processor = TextProcessor() self.tokenizer = None self.model = None self._initialize_model() def _initialize_model(self): """Initialize model and tokenizer.""" self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) self.model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, num_labels=2 ).to(self.device) try: model_path = "model_20250209_184929_acc1.0000.pt" checkpoint = torch.load(model_path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) logger.info(f"Loaded model from {model_path}") except Exception as e: logger.error(f"Failed to load model: {e}") raise def quick_scan(self, text: str) -> Dict: """ Quick scan implementation matching the second original program's predict method. """ if self.model is None or self.tokenizer is None: self._initialize_model() self.model.eval() sentences = self.processor.split_into_sentences(text) windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP) predictions = [] # Process windows in batches to save memory for i in range(0, len(windows), BATCH_SIZE): batch_windows = windows[i:i + BATCH_SIZE] inputs = self.tokenizer( batch_windows, truncation=True, padding=True, max_length=MAX_LENGTH, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) probs = F.softmax(outputs.logits, dim=-1) for idx, window in enumerate(batch_windows): prediction = { 'window': window, 'human_prob': probs[idx][1].item(), 'ai_prob': probs[idx][0].item(), 'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai' } predictions.append(prediction) # Clear memory del inputs, outputs, probs if torch.cuda.is_available(): torch.cuda.empty_cache() return self._aggregate_quick_predictions(predictions) def _aggregate_quick_predictions(self, predictions: List[Dict]) -> Dict: """ Aggregate predictions matching the second original program. """ if not predictions: return { 'human_prob': 0.0, 'ai_prob': 0.0, 'prediction': 'unknown', 'confidence': 0.0, 'num_windows': 0 } avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions) avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions) return { 'human_prob': avg_human_prob, 'ai_prob': avg_ai_prob, 'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai', 'confidence': max(avg_human_prob, avg_ai_prob), 'num_windows': len(predictions) } def detailed_scan(self, text: str) -> Dict: """ Detailed scan implementation matching the first original program's predict_with_sentence_scores method. """ if self.model is None or self.tokenizer is None: self._initialize_model() self.model.eval() sentences = self.processor.split_into_sentences(text) if not sentences: return {} # Create centered windows for each sentence windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE) # Track scores for each sentence sentence_appearances = {i: 0 for i in range(len(sentences))} sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))} # Process windows in batches to save memory for i in range(0, len(windows), BATCH_SIZE): batch_windows = windows[i:i + BATCH_SIZE] batch_indices = window_sentence_indices[i:i + BATCH_SIZE] inputs = self.tokenizer( batch_windows, truncation=True, padding=True, max_length=MAX_LENGTH, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) probs = F.softmax(outputs.logits, dim=-1) # Attribute window predictions back to individual sentences for window_idx, indices in enumerate(batch_indices): for sent_idx in indices: sentence_appearances[sent_idx] += 1 sentence_scores[sent_idx]['human_prob'] += probs[window_idx][1].item() sentence_scores[sent_idx]['ai_prob'] += probs[window_idx][0].item() # Clear memory del inputs, outputs, probs if torch.cuda.is_available(): torch.cuda.empty_cache() # Average the scores and create final sentence-level predictions sentence_predictions = [] for i in range(len(sentences)): if sentence_appearances[i] > 0: human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i] ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i] sentence_predictions.append({ 'sentence': sentences[i], 'human_prob': human_prob, 'ai_prob': ai_prob, 'prediction': 'human' if human_prob > ai_prob else 'ai', 'confidence': max(human_prob, ai_prob) }) # Generate highlighted text output highlighted_text = self._generate_highlighted_text(sentence_predictions) return { 'sentence_predictions': sentence_predictions, 'highlighted_text': highlighted_text, 'full_text': text, 'overall_prediction': self._aggregate_detailed_predictions(sentence_predictions) } def _generate_highlighted_text(self, sentence_predictions: List[Dict]) -> str: """Generate colored text output with highlighting based on predictions.""" highlighted_parts = [] for pred in sentence_predictions: sentence = pred['sentence'] confidence = pred['confidence'] if confidence >= CONFIDENCE_THRESHOLD: if pred['prediction'] == 'human': highlighted_parts.append(f"{Back.GREEN}{sentence}{Style.RESET_ALL}") else: highlighted_parts.append(f"{Back.RED}{sentence}{Style.RESET_ALL}") else: # Low confidence predictions get a lighter highlight if pred['prediction'] == 'human': highlighted_parts.append(f"{Back.LIGHTGREEN_EX}{sentence}{Style.RESET_ALL}") else: highlighted_parts.append(f"{Back.LIGHTRED_EX}{sentence}{Style.RESET_ALL}") return " ".join(highlighted_parts) def _aggregate_detailed_predictions(self, predictions: List[Dict]) -> Dict: """ Aggregate predictions matching the first original program. """ if not predictions: return { 'human_prob': 0.0, 'ai_prob': 0.0, 'prediction': 'unknown', 'confidence': 0.0, 'num_sentences': 0 } total_human_prob = sum(p['human_prob'] for p in predictions) total_ai_prob = sum(p['ai_prob'] for p in predictions) num_sentences = len(predictions) avg_human_prob = total_human_prob / num_sentences avg_ai_prob = total_ai_prob / num_sentences return { 'human_prob': avg_human_prob, 'ai_prob': avg_ai_prob, 'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai', 'confidence': max(avg_human_prob, avg_ai_prob), 'num_sentences': num_sentences } def main(): try: detector = AITextDetector() while True: print("\nAI Text Detector") print("===============") print("1. Quick Scan") print("2. Detailed Scan") print("3. Exit") choice = input("\nSelect an option (1-3): ").strip() if choice == "3": break if choice not in ["1", "2"]: print("Invalid choice. Please select 1, 2, or 3.") continue text = input("\nEnter text to analyze: ").strip() if choice == "1": # Quick scan result = detector.quick_scan(text) print("\nQuick Scan Results:") print("==================") print(f"Prediction: {result['prediction'].upper()}") print(f"Confidence: {result['confidence']*100:.1f}%") print(f"Human Probability: {result['human_prob']*100:.1f}%") print(f"AI Probability: {result['ai_prob']*100:.1f}%") print(f"Number of windows analyzed: {result['num_windows']}") else: # Detailed scan result = detector.detailed_scan(text) print("\nDetailed Analysis:") print("=================") # Print sentence-level predictions for pred in result['sentence_predictions']: confidence = pred['confidence'] * 100 print(f"\nSentence: {pred['sentence']}") print(f"Prediction: {pred['prediction'].upper()}") print(f"Confidence: {confidence:.1f}%") # Print highlighted text print("\nHighlighted Text Analysis:") print("=========================") print(result['highlighted_text']) # Print final prediction final_pred = result['overall_prediction'] print(f"\nFINAL PREDICTION: {final_pred['prediction'].upper()}") print(f"Overall confidence: {final_pred['confidence']*100:.1f}%") print(f"Number of sentences analyzed: {final_pred['num_sentences']}") except Exception as e: logger.error(f"An error occurred: {e}") raise if __name__ == "__main__": main()