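# Repository web-page header captured together with this file (page chrome, not program code):
#   avatar caption: ApsidalSolid4's picture | commit: "Update app.py" (f69e7ad, verified)
#   page links: raw, history, blame | file size: 20.9 kB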
import html
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import List, Dict, Tuple

import gradio as gr
import numpy as np
import spacy
import torch
import torch.nn.functional as F
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure logging
# The root logger is set to INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
# Token limit handed to the tokenizer (as model_max_length and as max_length with truncation on).
MAX_LENGTH = 512
# Checkpoint name passed to from_pretrained for both the tokenizer and the model.
MODEL_NAME = "microsoft/deberta-v3-small"
# Window size, in sentences, passed to both window builders.
WINDOW_SIZE = 6
# Number of sentences shared by consecutive windows in create_windows (stride = size - overlap).
WINDOW_OVERLAP = 2
# Sentence confidence at or above this value gets the stronger highlight colour in the HTML output.
CONFIDENCE_THRESHOLD = 0.65
# Number of windows sent through the model per forward pass.
BATCH_SIZE = 8  # Reduced batch size for CPU
# Used for the thread pool size and, when CUDA is unavailable, for the torch thread counts.
MAX_WORKERS = 4  # Number of worker threads for processing
class TextWindowProcessor:
    """Split text into sentences and group the sentences into windows.

    Sentence splitting uses the spaCy ``en_core_web_sm`` pipeline with every
    component except the sentencizer disabled. Two window layouts are offered:
    strided windows (``create_windows``) and one window centered on each
    sentence (``create_centered_windows``).
    """

    def __init__(self):
        """Load the spaCy pipeline, downloading the model if it is missing."""
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spacy.load raises OSError when the model package is not installed.
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")
        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')
        # Only sentence boundaries are needed, so switch off everything else.
        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        if disabled_pipes:
            self.nlp.disable_pipes(*disabled_pipes)
        # Initialize thread pool for parallel processing.
        # Kept as a public attribute; none of the methods in this class submit work to it.
        self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

    def split_into_sentences(self, text: str) -> List[str]:
        """Return the sentences of *text*, stripped, with empty ones dropped."""
        doc = self.nlp(text)
        stripped = (str(sent).strip() for sent in doc.sents)
        return [sentence for sentence in stripped if sentence]

    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        """Join sentences into strided windows that together cover every sentence.

        Args:
            sentences: Sentences in document order.
            window_size: Number of sentences per window; must be at least 1.
            overlap: Number of sentences shared by consecutive windows.

        Returns:
            One space-joined string per window. An empty input gives an empty
            list; an input shorter than ``window_size`` gives a single window
            holding all sentences.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be at least 1, got {window_size}")
        if not sentences:
            return []
        if len(sentences) < window_size:
            return [" ".join(sentences)]

        # An overlap >= window_size would give a zero or negative stride;
        # fall back to advancing one sentence at a time.
        stride = max(1, window_size - overlap)
        last_start = len(sentences) - window_size
        starts = list(range(0, last_start + 1, stride))
        # When the stride does not land exactly on the tail, add one final
        # window flush with the end so the last sentences are not skipped.
        if starts[-1] != last_start:
            starts.append(last_start)
        return [" ".join(sentences[start:start + window_size]) for start in starts]

    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """Create one window per sentence, centered on it and clipped at the text edges.

        Args:
            sentences: Sentences in document order.
            window_size: Nominal window size; ``window_size // 2`` sentences are
                taken on each side of the center sentence.

        Returns:
            A pair ``(windows, window_sentence_indices)``. ``windows[k]`` is the
            space-joined text of the window centered on sentence ``k`` and
            ``window_sentence_indices[k]`` lists the sentence indices it contains.
        """
        windows = []
        window_sentence_indices = []
        half_window = window_size // 2
        total = len(sentences)
        for center in range(total):
            # Clip the window to the valid sentence range.
            start_idx = max(0, center - half_window)
            end_idx = min(total, center + half_window + 1)
            windows.append(" ".join(sentences[start_idx:end_idx]))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices
class TextClassifier:
    """Classify text as human- or AI-written with a two-label sequence classifier.

    ``quick_scan`` gives a single verdict from strided sentence windows;
    ``detailed_scan`` gives a verdict per sentence from windows centered on each
    sentence. Throughout the class, model output index 0 is reported as the
    "ai" probability and index 1 as the "human" probability.
    """

    # Share of a window's score credited to the sentence the window is centered on.
    CENTER_WEIGHT = 0.7
    # Share split evenly among the other sentences in the window.
    EDGE_WEIGHT_TOTAL = 0.3
    # Blend factor used when a sentence's label differs from a neighbour's.
    SMOOTH_FACTOR = 0.1

    def __init__(self):
        """Configure CPU threading, choose the device and load tokenizer and model."""
        # Set thread configuration before any model loading or parallel work
        if not torch.cuda.is_available():
            torch.set_num_threads(MAX_WORKERS)
            try:
                torch.set_num_interop_threads(MAX_WORKERS)
            except RuntimeError as err:
                # PyTorch refuses this call once it has already been made or
                # parallel work has started (e.g. a second instance is built).
                logger.warning("Could not set inter-op thread count: %s", err)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        """Initialize the model and tokenizer.

        Loads the base checkpoint named by ``self.model_name`` and, if the
        fine-tuned checkpoint file exists in the working directory, overwrites
        the weights with its ``model_state_dict``. The model is left in eval mode.
        """
        logger.info("Initializing model and tokenizer...")
        from transformers import DebertaV2TokenizerFast
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
            self.model_name,
            model_max_length=MAX_LENGTH,
            use_fast=True
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)
        model_path = "model_20250209_184929_acc1.0000.pt"
        if os.path.exists(model_path):
            logger.info("Loading custom model from %s", model_path)
            # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")
        self.model.eval()

    @staticmethod
    def _label(human_prob: float, ai_prob: float) -> str:
        """Return 'human' when the human probability is strictly larger, else 'ai'."""
        return 'human' if human_prob > ai_prob else 'ai'

    @staticmethod
    def _unknown_quick_result() -> Dict:
        """Return the quick-scan result used when there is nothing to classify."""
        return {
            'prediction': 'unknown',
            'confidence': 0.0,
            'num_windows': 0
        }

    @staticmethod
    def _empty_detailed_result() -> Dict:
        """Return the detailed-scan result used when there is nothing to classify."""
        return {
            'sentence_predictions': [],
            'highlighted_text': '',
            'full_text': '',
            'overall_prediction': {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        }

    @classmethod
    def _window_weights(cls, indices: List[int], center_sentence: int) -> List[float]:
        """Return one weight per entry of *indices* for a centered window.

        The sentence the window was built around gets ``CENTER_WEIGHT``; the
        remaining sentences share ``EDGE_WEIGHT_TOTAL`` equally. A window that
        holds a single sentence has no edges, so no division takes place.
        """
        edge_count = len(indices) - 1
        edge_weight = cls.EDGE_WEIGHT_TOTAL / edge_count if edge_count > 0 else 0.0
        return [
            cls.CENTER_WEIGHT if sent_idx == center_sentence else edge_weight
            for sent_idx in indices
        ]

    def _predict_probabilities(self, windows: List[str]) -> List[List[float]]:
        """Run the model over *windows* in batches and return softmax rows.

        Each returned row is the list of class probabilities for one window,
        in the same order as *windows*.
        """
        rows: List[List[float]] = []
        # Process windows in smaller batches for CPU efficiency
        for start in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[start:start + BATCH_SIZE]
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)
            rows.extend(probs.tolist())
            # Clean up GPU memory if available
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return rows

    def quick_scan(self, text: str) -> Dict:
        """Perform a quick scan using simple window analysis.

        Returns:
            A dict with ``prediction`` ('human', 'ai' or 'unknown'),
            ``confidence`` (the larger of the two averaged probabilities) and
            ``num_windows`` (how many windows were scored).
        """
        if not text.strip():
            return self._unknown_quick_result()
        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)
        # Never send a blank window through the model.
        windows = [window for window in windows if window.strip()]
        rows = self._predict_probabilities(windows)
        if not rows:
            return self._unknown_quick_result()
        avg_ai_prob = sum(row[0] for row in rows) / len(rows)
        avg_human_prob = sum(row[1] for row in rows) / len(rows)
        return {
            'prediction': self._label(avg_human_prob, avg_ai_prob),
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(rows)
        }

    def detailed_scan(self, text: str) -> Dict:
        """Perform a detailed scan with sentence-level analysis.

        Every sentence gets a window centered on it. Each window's probabilities
        are credited to the sentences it contains, weighted toward the center
        sentence, and the per-sentence averages are lightly smoothed wherever a
        sentence's label differs from a neighbour's.

        Returns:
            A dict with ``sentence_predictions`` (list of per-sentence dicts),
            ``highlighted_text`` (HTML), ``full_text`` and ``overall_prediction``.
            Blank input, or input that yields no sentences, gives the same keys
            with empty values and an 'unknown' overall prediction.
        """
        # Clean up trailing whitespace
        text = text.rstrip()
        if not text.strip():
            return self._empty_detailed_result()
        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return self._empty_detailed_result()

        # Create centered windows for each sentence; window k is centered on sentence k.
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
        rows = self._predict_probabilities(windows)

        # Track weighted scores for each sentence
        count = len(sentences)
        appearances = [0.0] * count
        human_totals = [0.0] * count
        ai_totals = [0.0] * count
        for center, (indices, row) in enumerate(zip(window_sentence_indices, rows)):
            weights = self._window_weights(indices, center)
            for sent_idx, weight in zip(indices, weights):
                appearances[sent_idx] += weight
                human_totals[sent_idx] += weight * row[1]
                ai_totals[sent_idx] += weight * row[0]

        # Weighted average per sentence as (human, ai); None when a sentence was never scored.
        averaged = [
            (human_totals[i] / appearances[i], ai_totals[i] / appearances[i])
            if appearances[i] > 0 else None
            for i in range(count)
        ]

        # Calculate final predictions with boundary smoothing
        sentence_predictions = []
        for i, current in enumerate(averaged):
            if current is None:
                continue
            human_prob, ai_prob = current
            has_neighbours = (
                0 < i < count - 1
                and averaged[i - 1] is not None
                and averaged[i + 1] is not None
            )
            if has_neighbours:
                prev_human, prev_ai = averaged[i - 1]
                next_human, next_ai = averaged[i + 1]
                current_pred = self._label(human_prob, ai_prob)
                # Only adjust where the label changes relative to a neighbour.
                if (current_pred != self._label(prev_human, prev_ai)
                        or current_pred != self._label(next_human, next_ai)):
                    smooth_factor = self.SMOOTH_FACTOR
                    human_prob = (human_prob * (1 - smooth_factor) +
                                  (prev_human + next_human) * smooth_factor / 2)
                    ai_prob = (ai_prob * (1 - smooth_factor) +
                               (prev_ai + next_ai) * smooth_factor / 2)
            sentence_predictions.append({
                'sentence': sentences[i],
                'human_prob': human_prob,
                'ai_prob': ai_prob,
                'prediction': self._label(human_prob, ai_prob),
                'confidence': max(human_prob, ai_prob)
            })
        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }

    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        """Format predictions as HTML with color-coding.

        Each sentence becomes a ``<span>`` whose background is green for 'human'
        and red for 'ai', in a stronger shade when the confidence reaches
        ``CONFIDENCE_THRESHOLD``. Sentence text is HTML-escaped because it comes
        straight from user input.
        """
        html_parts = []
        for pred in sentence_predictions:
            sentence = html.escape(pred['sentence'])
            confidence = pred['confidence']
            if confidence >= CONFIDENCE_THRESHOLD:
                if pred['prediction'] == 'human':
                    color = "#90EE90"  # Light green
                else:
                    color = "#FFB6C6"  # Light red
            else:
                if pred['prediction'] == 'human':
                    color = "#E8F5E9"  # Very light green
                else:
                    color = "#FFEBEE"  # Very light red
            html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
        return " ".join(html_parts)

    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        """Aggregate predictions from multiple sentences into a single prediction.

        Averages ``human_prob`` and ``ai_prob`` over *predictions* and returns a
        dict with ``prediction``, ``confidence`` and ``num_sentences``. An empty
        list gives an 'unknown' prediction with zero confidence.
        """
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        num_sentences = len(predictions)
        avg_human_prob = sum(p['human_prob'] for p in predictions) / num_sentences
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / num_sentences
        return {
            'prediction': self._label(avg_human_prob, avg_ai_prob),
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
def analyze_text(text: str, mode: str, classifier: "TextClassifier") -> tuple:
    """Analyze text using specified mode and return formatted results.

    Args:
        text: The text to classify.
        mode: ``"quick"`` runs ``classifier.quick_scan``; any other value runs
            ``classifier.detailed_scan``.
        classifier: Object providing ``quick_scan`` and ``detailed_scan``.

    Returns:
        A 3-tuple ``(html, sentence_report, overall_report)``. The first item is
        rendered as HTML by the interface, so in quick mode the input text is
        HTML-escaped before being echoed back.
    """
    if mode == "quick":
        result = classifier.quick_scan(text)
        quick_analysis = (
            f"\nPREDICTION: {result['prediction'].upper()}\n"
            f"Confidence: {result['confidence']*100:.1f}%\n"
            f"Windows analyzed: {result['num_windows']}\n"
        )
        return (
            html.escape(text),  # No highlighting in quick mode
            "Quick scan mode - no sentence-level analysis available",
            quick_analysis
        )

    analysis = classifier.detailed_scan(text)
    # detailed_scan may hand back an empty dict when it finds no sentences;
    # fall back to an empty report instead of raising KeyError.
    sentence_predictions = analysis.get('sentence_predictions', [])
    final_pred = analysis.get('overall_prediction') or {
        'prediction': 'unknown',
        'confidence': 0.0,
        'num_sentences': 0
    }
    detailed_analysis = []
    for pred in sentence_predictions:
        confidence = pred['confidence'] * 100
        detailed_analysis.extend([
            f"Sentence: {pred['sentence']}",
            f"Prediction: {pred['prediction'].upper()}",
            f"Confidence: {confidence:.1f}%",
            "-" * 50,
        ])
    overall_result = (
        f"\nFINAL PREDICTION: {final_pred['prediction'].upper()}\n"
        f"Overall confidence: {final_pred['confidence']*100:.1f}%\n"
        f"Number of sentences analyzed: {final_pred['num_sentences']}\n"
    )
    return (
        analysis.get('highlighted_text', ''),
        "\n".join(detailed_analysis),
        overall_result
    )
# Initialize the classifier globally
# Built once at import time; the constructor loads the spaCy pipeline and the model,
# so importing this module is slow and may download files.
classifier = TextClassifier()
# Create Gradio interface
# The lambda binds the shared classifier so the UI callback only receives (text, mode).
demo = gr.Interface(
    fn=lambda text, mode: analyze_text(text, mode, classifier),
    inputs=[
        gr.Textbox(
            lines=8,
            placeholder="Enter text to analyze...",
            label="Input Text"
        ),
        gr.Radio(
            choices=["quick", "detailed"],
            value="quick",
            label="Analysis Mode",
            info="Quick mode for faster analysis, Detailed mode for sentence-level analysis"
        )
    ],
    # Output order matches the 3-tuple returned by analyze_text:
    # (HTML, per-sentence report, overall report).
    outputs=[
        gr.HTML(label="Highlighted Analysis"),
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
        gr.Textbox(label="Overall Result", lines=4)
    ],
    title="AI Text Detector",
    description="Analyze text to detect if it was written by a human or AI. Choose between quick scan and detailed sentence-level analysis.",
    examples=[
        ["This is a sample text written by a human. It contains multiple sentences with different ideas. The analysis will show how each sentence is classified.", "quick"],
        ["This is a sample text written by a human. It contains multiple sentences with different ideas. The analysis will show how each sentence is classified.", "detailed"],
    ],
    api_name="predict",
    flagging_mode="never"
)
# NOTE(review): assumes the Interface exposes its ASGI application as `.app` before
# launch() is called — confirm against the installed Gradio version.
app = demo.app
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; the inline comment marks it as a development setting — confirm
# before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],  # Explicitly list methods
    allow_headers=["*"],
)
# The CORS middleware above is registered at import time, i.e. before the
# server is started here.
if __name__ == "__main__":
    # Turn on request queuing, then start the server on all interfaces.
    demo.queue()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)