import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import spacy
from typing import List, Dict, Tuple
import logging
import os
import html
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 17
WINDOW_OVERLAP = 2  # currently unused; windows are centered per sentence instead
CONFIDENCE_THRESHOLD = 0.65
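
# How the detector works: each sentence is scored using a window of up to
# WINDOW_SIZE sentences centered on it, the probabilities from every window a
# sentence appears in are averaged into a per-sentence score, and the
# per-sentence scores are averaged again into an overall verdict.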

class TextWindowProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        # Keep only the lightweight sentencizer; the full spaCy pipeline
        # (tagger, parser, NER) is unnecessary for sentence splitting.
        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')
        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """Create windows centered around each sentence for detailed analysis."""
        windows = []
        window_sentence_indices = []
        for i in range(len(sentences)):
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)

            # Near the edges of the text, expand the window in the other
            # direction so every window still has up to window_size sentences.
            if start_idx == 0:
                end_idx = min(len(sentences), window_size)
            elif end_idx == len(sentences):
                start_idx = max(0, len(sentences) - window_size)

            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices
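
# Worked example: with WINDOW_SIZE = 17 (half_window = 8) and 20 sentences,
# sentence 0 gets the edge-expanded window covering sentences 0..16, while
# sentence 10 gets the centered window 2..18. With only 5 sentences, every
# window covers all 5, so each sentence is scored 5 times and then averaged.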

class TextClassifier:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        """Initialize the model and tokenizer."""
        logger.info("Initializing model and tokenizer...")

        # Download and save tokenizer files locally, then load from local files
        local_tokenizer_path = "tokenizer"
        if not os.path.exists(local_tokenizer_path):
            AutoTokenizer.from_pretrained(self.model_name).save_pretrained(local_tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)

        # First initialize the base model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)

        # Look for a fine-tuned checkpoint in the same directory as the code
        model_path = "model.pt"  # Your model file should be uploaded as model.pt
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")

        self.model.eval()
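
    # Note: the checkpoint is expected to be a dict with a 'model_state_dict'
    # key, i.e. saved by some training script (not shown here) roughly as:
    #   torch.save({'model_state_dict': model.state_dict()}, "model.pt")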

    def predict_with_sentence_scores(self, text: str) -> Dict:
        """Predict with sentence-level granularity using overlapping windows."""
        empty_result = {
            'sentence_predictions': [],
            'highlighted_text': '',
            'full_text': '',
            'overall_prediction': {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        }
        if not text.strip():
            return empty_result

        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            # Return the same empty structure (rather than {}) so downstream
            # formatting code never hits a KeyError.
            return empty_result

        # Create centered windows for each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        # Track scores for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        # Process windows in batches to save memory
        batch_size = 16
        for i in range(0, len(windows), batch_size):
            batch_windows = windows[i:i + batch_size]
            batch_indices = window_sentence_indices[i:i + batch_size]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            for window_idx, indices in enumerate(batch_indices):
                for sent_idx in indices:
                    sentence_appearances[sent_idx] += 1
                    # Label convention assumed here: index 0 = AI, index 1 = human
                    sentence_scores[sent_idx]['human_prob'] += probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += probs[window_idx][0].item()

        # Average the scores and create final sentence-level predictions
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        # Generate analysis outputs
        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }

    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        """Format predictions as HTML with color-coding."""
        html_parts = []
        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']

            # Strong colors above the confidence threshold, pale ones below it
            if confidence >= CONFIDENCE_THRESHOLD:
                color = "#90EE90" if pred['prediction'] == 'human' else "#FFB6C6"  # light green / light red
            else:
                color = "#E8F5E9" if pred['prediction'] == 'human' else "#FFEBEE"  # very light green / very light red

            # Escape the sentence so user-supplied text cannot break the markup
            html_parts.append(f'<span style="background-color: {color};">{html.escape(sentence)}</span>')
        return " ".join(html_parts)

    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        """Aggregate predictions from multiple sentences into a single prediction."""
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        num_sentences = len(predictions)
        avg_human_prob = sum(p['human_prob'] for p in predictions) / num_sentences
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / num_sentences
        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
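
    # Example of the aggregation arithmetic: three sentences with human_prob
    # 0.8, 0.6 and 0.7 average to 0.7, so the overall prediction is 'human'
    # with confidence 0.7 (reported as 70% by the interface below).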


def analyze_text(text: str, classifier: TextClassifier) -> tuple:
    """Analyze text and return formatted results for Gradio interface."""
    # Get predictions
    analysis = classifier.predict_with_sentence_scores(text)

    # Format sentence-by-sentence analysis
    detailed_analysis = []
    for pred in analysis['sentence_predictions']:
        confidence = pred['confidence'] * 100
        detailed_analysis.append(f"Sentence: {pred['sentence']}")
        detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
        detailed_analysis.append(f"Confidence: {confidence:.1f}%")
        detailed_analysis.append("-" * 50)

    # Format overall prediction
    final_pred = analysis['overall_prediction']
    overall_result = f"""
FINAL PREDICTION: {final_pred['prediction'].upper()}
Overall confidence: {final_pred['confidence']*100:.1f}%
Number of sentences analyzed: {final_pred['num_sentences']}
"""

    return (
        analysis['highlighted_text'],
        "\n".join(detailed_analysis),
        overall_result
    )
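
# Direct (non-UI) usage looks like this; "Some sample text." is only a
# placeholder input:
#   clf = TextClassifier()
#   highlighted_html, per_sentence_report, summary = analyze_text("Some sample text.", clf)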

# Initialize the classifier globally
classifier = TextClassifier()

# Create Gradio interface
demo = gr.Interface(
    fn=lambda text: analyze_text(text, classifier),
    inputs=gr.Textbox(
        lines=8,
        placeholder="Enter text to analyze...",
        label="Input Text"
    ),
    outputs=[
        gr.HTML(label="Highlighted Analysis"),
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
        gr.Textbox(label="Overall Result", lines=4)
    ],
    title="AI Text Detector",
    description="Analyze text to detect if it was written by a human or AI. Text is analyzed sentence by sentence, with color coding indicating the prediction confidence.",
    examples=[
        ["This is a sample text written by a human. It contains multiple sentences with different ideas. The analysis will show how each sentence is classified. This demonstrates the AI detection capabilities."],
    ],
    allow_flagging="never"
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()