import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import spacy
from typing import List, Dict, Tuple
import logging
import os
from colorama import init, Fore, Back, Style

# Initialize colorama for colored terminal output
init()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants - matching original implementations
MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 17
WINDOW_OVERLAP = 2
CONFIDENCE_THRESHOLD = 0.65
BATCH_SIZE = 16  # Matching original batch size
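
# With the defaults above, consecutive quick-scan windows start
# WINDOW_SIZE - WINDOW_OVERLAP = 15 sentences apart, so adjacent windows
# share two sentences; see TextProcessor.create_windows below.
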
class TextProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            # Download via spaCy's own CLI helper rather than shelling out,
            # so the model lands in the same environment as this process
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")
        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')
        # Keep only the sentencizer; the other components aren't needed here
        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        if len(sentences) < window_size:
            return [" ".join(sentences)]
        windows = []
        stride = window_size - overlap
        # Note: with this stride logic, trailing sentences that don't fill a
        # final full window are left out of the quick scan
        for i in range(0, len(sentences) - window_size + 1, stride):
            window = sentences[i:i + window_size]
            windows.append(" ".join(window))
        return windows
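
    # Worked example (illustrative, not executed): with window_size=3 and
    # overlap=1 the stride is 2, so sentences [s0, s1, s2, s3, s4] yield the
    # windows "s0 s1 s2" and "s2 s3 s4" -- consecutive windows share one
    # sentence. With the module defaults (WINDOW_SIZE=17, WINDOW_OVERLAP=2)
    # adjacent windows share two sentences.
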
    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """Create windows centered around each sentence for detailed analysis."""
        windows = []
        window_sentence_indices = []
        for i in range(len(sentences)):
            # Calculate window boundaries centered on the current sentence
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)
            # Near the edges, widen the window so it still spans window_size
            # sentences when enough text is available
            if start_idx == 0:
                end_idx = min(len(sentences), window_size)
            elif end_idx == len(sentences):
                start_idx = max(0, len(sentences) - window_size)
            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices
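
    # Worked example (illustrative): with window_size=5 and 10 sentences,
    # the window for sentence 0 is clamped to sentences 0-4, the window for
    # sentence 5 is the centered span 3-7, and the window for sentence 9 is
    # pulled back to sentences 5-9. Every sentence therefore appears in
    # several windows, and detailed_scan averages the scores it receives.
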
class AITextDetector:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = TextProcessor()
        self.tokenizer = None
        self.model = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=2
        ).to(self.device)
        try:
            model_path = "model_20250209_184929_acc1.0000.pt"
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            logger.info(f"Loaded model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
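
    # Note: the checkpoint above is assumed to be a dict saved during training
    # via something like torch.save({'model_state_dict': model.state_dict()}, path);
    # only the 'model_state_dict' entry is read here. A checkpoint saved as a
    # bare state_dict would be loaded with load_state_dict(checkpoint) instead.
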
    def quick_scan(self, text: str) -> Dict:
        """
        Quick scan: classify fixed-stride sentence windows and average their
        probabilities (matching the second original program's predict method).
        """
        if self.model is None or self.tokenizer is None:
            self._initialize_model()
        self.model.eval()
        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)
        predictions = []
        # Process windows in batches to save memory
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)
            # Label convention used throughout: index 0 = AI, index 1 = human
            for idx, window in enumerate(batch_windows):
                prediction = {
                    'window': window,
                    'human_prob': probs[idx][1].item(),
                    'ai_prob': probs[idx][0].item(),
                    'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai'
                }
                predictions.append(prediction)
            # Clear memory
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return self._aggregate_quick_predictions(predictions)
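
    # For reference, the summary returned above looks like (values
    # illustrative, not from a real run):
    #   {'human_prob': 0.91, 'ai_prob': 0.09, 'prediction': 'human',
    #    'confidence': 0.91, 'num_windows': 3}
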
    def _aggregate_quick_predictions(self, predictions: List[Dict]) -> Dict:
        """
        Aggregate window-level predictions into a single verdict by averaging
        probabilities (matching the second original program).
        """
        if not predictions:
            return {
                'human_prob': 0.0,
                'ai_prob': 0.0,
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }
        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)
        return {
            'human_prob': avg_human_prob,
            'ai_prob': avg_ai_prob,
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(predictions)
        }

    def detailed_scan(self, text: str) -> Dict:
        """
        Detailed scan: classify a window centered on each sentence and average
        every window score a sentence appears in (matching the first original
        program's predict_with_sentence_scores method).
        """
        if self.model is None or self.tokenizer is None:
            self._initialize_model()
        self.model.eval()
        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return {}
        # Create centered windows for each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
        # Track accumulated scores and appearance counts for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}
        # Process windows in batches to save memory
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            batch_indices = window_sentence_indices[i:i + BATCH_SIZE]
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)
            # Attribute window predictions back to individual sentences
            for window_idx, indices in enumerate(batch_indices):
                for sent_idx in indices:
                    sentence_appearances[sent_idx] += 1
                    sentence_scores[sent_idx]['human_prob'] += probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += probs[window_idx][0].item()
            # Clear memory
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # Average the scores and create final sentence-level predictions
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })
        # Generate highlighted text output
        highlighted_text = self._generate_highlighted_text(sentence_predictions)
        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': highlighted_text,
            'full_text': text,
            'overall_prediction': self._aggregate_detailed_predictions(sentence_predictions)
        }
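
    # For reference, the detailed result looks like (values illustrative):
    #   {'sentence_predictions': [{'sentence': '...', 'human_prob': 0.88,
    #                              'ai_prob': 0.12, 'prediction': 'human',
    #                              'confidence': 0.88}, ...],
    #    'highlighted_text': '<ANSI-colored text>',
    #    'full_text': '<the input text>',
    #    'overall_prediction': <same shape as the quick-scan summary,
    #                           with 'num_sentences' instead of 'num_windows'>}
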
    def _generate_highlighted_text(self, sentence_predictions: List[Dict]) -> str:
        """Generate ANSI-colored text output with highlighting based on predictions."""
        highlighted_parts = []
        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']
            if confidence >= CONFIDENCE_THRESHOLD:
                # High-confidence predictions get a solid highlight
                if pred['prediction'] == 'human':
                    highlighted_parts.append(f"{Back.GREEN}{sentence}{Style.RESET_ALL}")
                else:
                    highlighted_parts.append(f"{Back.RED}{sentence}{Style.RESET_ALL}")
            else:
                # Low-confidence predictions get a lighter highlight
                if pred['prediction'] == 'human':
                    highlighted_parts.append(f"{Back.LIGHTGREEN_EX}{sentence}{Style.RESET_ALL}")
                else:
                    highlighted_parts.append(f"{Back.LIGHTRED_EX}{sentence}{Style.RESET_ALL}")
        return " ".join(highlighted_parts)

    def _aggregate_detailed_predictions(self, predictions: List[Dict]) -> Dict:
        """
        Aggregate sentence-level predictions into a single verdict by
        averaging probabilities (matching the first original program).
        """
        if not predictions:
            return {
                'human_prob': 0.0,
                'ai_prob': 0.0,
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        total_human_prob = sum(p['human_prob'] for p in predictions)
        total_ai_prob = sum(p['ai_prob'] for p in predictions)
        num_sentences = len(predictions)
        avg_human_prob = total_human_prob / num_sentences
        avg_ai_prob = total_ai_prob / num_sentences
        return {
            'human_prob': avg_human_prob,
            'ai_prob': avg_ai_prob,
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
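
# Programmatic usage (sketch; names match the classes above, the text is a
# placeholder -- the interactive menu in main() below is the actual entry point):
#   detector = AITextDetector()
#   summary = detector.quick_scan("Some text to check...")
#   detail = detector.detailed_scan("Some text to check...")
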
def main():
    try:
        detector = AITextDetector()
        while True:
            print("\nAI Text Detector")
            print("===============")
            print("1. Quick Scan")
            print("2. Detailed Scan")
            print("3. Exit")
            choice = input("\nSelect an option (1-3): ").strip()
            if choice == "3":
                break
            if choice not in ["1", "2"]:
                print("Invalid choice. Please select 1, 2, or 3.")
                continue
            text = input("\nEnter text to analyze: ").strip()
            # Guard against empty input: detailed_scan returns {} for empty
            # text, which would otherwise raise a KeyError below
            if not text:
                print("No text entered. Please try again.")
                continue
            if choice == "1":
                # Quick scan
                result = detector.quick_scan(text)
                print("\nQuick Scan Results:")
                print("==================")
                print(f"Prediction: {result['prediction'].upper()}")
                print(f"Confidence: {result['confidence']*100:.1f}%")
                print(f"Human Probability: {result['human_prob']*100:.1f}%")
                print(f"AI Probability: {result['ai_prob']*100:.1f}%")
                print(f"Number of windows analyzed: {result['num_windows']}")
            else:
                # Detailed scan
                result = detector.detailed_scan(text)
                print("\nDetailed Analysis:")
                print("=================")
                # Print sentence-level predictions
                for pred in result['sentence_predictions']:
                    confidence = pred['confidence'] * 100
                    print(f"\nSentence: {pred['sentence']}")
                    print(f"Prediction: {pred['prediction'].upper()}")
                    print(f"Confidence: {confidence:.1f}%")
                # Print highlighted text
                print("\nHighlighted Text Analysis:")
                print("=========================")
                print(result['highlighted_text'])
                # Print final prediction
                final_pred = result['overall_prediction']
                print(f"\nFINAL PREDICTION: {final_pred['prediction'].upper()}")
                print(f"Overall confidence: {final_pred['confidence']*100:.1f}%")
                print(f"Number of sentences analyzed: {final_pred['num_sentences']}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise

if __name__ == "__main__":
    main()