Create app.py
app.py
ADDED
@@ -0,0 +1,263 @@
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import spacy
from typing import List, Dict
import logging
import os
import html  # for escaping user text before it is embedded in HTML output
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 17
WINDOW_OVERLAP = 2
CONFIDENCE_THRESHOLD = 0.65
class TextWindowProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            os.system("python -m spacy download en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')

        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_centered_windows(self, sentences: List[str], window_size: int) -> tuple[List[str], List[List[int]]]:
        """Create windows centered around each sentence for detailed analysis."""
        windows = []
        window_sentence_indices = []

        for i in range(len(sentences)):
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)

            if start_idx == 0:
                end_idx = min(len(sentences), window_size)
            elif end_idx == len(sentences):
                start_idx = max(0, len(sentences) - window_size)

            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))

        return windows, window_sentence_indices
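
# Worked example of the windowing above (illustrative, not from the original
# commit): with 5 sentences and window_size=3, sentence 0 gets the clamped
# window [0, 1, 2], sentence 2 gets the centered window [1, 2, 3], and
# sentence 4 gets the end-clamped window [2, 3, 4]. With the default
# WINDOW_SIZE of 17, any text shorter than 17 sentences produces the same
# window (the whole text) for every sentence.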
class TextClassifier:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        """Initialize the model and tokenizer."""
        logger.info("Initializing model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # First initialize the base model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)

        # Load your custom trained weights
        model_path = "/home/user/.cache/model_files/model.pt"  # Adjust filename as needed
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")

        self.model.eval()
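
    # Assumptions made by the loader above: the checkpoint is a dict with a
    # 'model_state_dict' key, as produced by e.g.
    #     torch.save({'model_state_dict': model.state_dict()}, path)
    # The label convention used in prediction (index 1 = human, index 0 = AI)
    # only holds for the custom fine-tuned weights; with the bare base model
    # the two output labels are untrained and effectively arbitrary.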
    def predict_with_sentence_scores(self, text: str) -> Dict:
        """Predict with sentence-level granularity using overlapping windows."""
        empty_result = {
            'sentence_predictions': [],
            'highlighted_text': '',
            'full_text': '',
            'overall_prediction': {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }
        }

        if not text.strip():
            return empty_result

        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            # Return the same empty structure so callers can index into the
            # result safely.
            return empty_result

        # Create centered windows for each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        # Track scores for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        # Process windows in batches to save memory
        batch_size = 16
        for i in range(0, len(windows), batch_size):
            batch_windows = windows[i:i + batch_size]
            batch_indices = window_sentence_indices[i:i + batch_size]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            for window_idx, indices in enumerate(batch_indices):
                for sent_idx in indices:
                    sentence_appearances[sent_idx] += 1
                    sentence_scores[sent_idx]['human_prob'] += probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += probs[window_idx][0].item()

        # Average the scores and create final sentence-level predictions
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        # Generate analysis outputs
        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }
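
    # Scoring sketch: because windows overlap, a sentence that appears in k
    # windows has its probabilities summed k times and divided by k above,
    # i.e. each sentence's score is the mean of the model's outputs over every
    # window containing it. For a 3-sentence text with window_size=3, all
    # three windows span the whole text, so every sentence receives identical
    # (fully smoothed) scores.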
    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        """Format predictions as HTML with color-coding."""
        html_parts = []

        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']

            if confidence >= CONFIDENCE_THRESHOLD:
                if pred['prediction'] == 'human':
                    color = "#90EE90"  # Light green
                else:
                    color = "#FFB6C6"  # Light pink
            else:
                if pred['prediction'] == 'human':
                    color = "#E8F5E9"  # Very light green
                else:
                    color = "#FFEBEE"  # Very light red

            # Escape the sentence so user text cannot inject markup into the output.
            html_parts.append(f'<span style="background-color: {color};">{html.escape(sentence)}</span>')

        return " ".join(html_parts)
    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        """Aggregate predictions from multiple sentences into a single prediction."""
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        total_human_prob = sum(p['human_prob'] for p in predictions)
        total_ai_prob = sum(p['ai_prob'] for p in predictions)
        num_sentences = len(predictions)

        avg_human_prob = total_human_prob / num_sentences
        avg_ai_prob = total_ai_prob / num_sentences

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }
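
# Worked example of the aggregation (illustrative): given two sentences scored
# (human=0.80, ai=0.20) and (human=0.40, ai=0.60), the averages are human=0.60
# and ai=0.40, so the overall prediction is 'human' with 60% confidence over
# 2 sentences. Note this averages probabilities rather than majority-voting
# sentence labels, so a few very confident sentences can outweigh many
# borderline ones.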
def analyze_text(text: str, classifier: TextClassifier) -> tuple:
    """Analyze text and return formatted results for Gradio interface."""
    # Get predictions
    analysis = classifier.predict_with_sentence_scores(text)

    # Format sentence-by-sentence analysis
    detailed_analysis = []
    for pred in analysis['sentence_predictions']:
        confidence = pred['confidence'] * 100
        detailed_analysis.append(f"Sentence: {pred['sentence']}")
        detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
        detailed_analysis.append(f"Confidence: {confidence:.1f}%")
        detailed_analysis.append("-" * 50)

    # Format overall prediction
    final_pred = analysis['overall_prediction']
    overall_result = f"""
FINAL PREDICTION: {final_pred['prediction'].upper()}
Overall confidence: {final_pred['confidence']*100:.1f}%
Number of sentences analyzed: {final_pred['num_sentences']}
"""

    return (
        analysis['highlighted_text'],
        "\n".join(detailed_analysis),
        overall_result
    )
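
# Standalone usage sketch (an assumption, not part of the original commit;
# module name taken from the file name app.py):
#     from app import TextClassifier, analyze_text
#     clf = TextClassifier()
#     html_out, per_sentence, overall = analyze_text("Some text to check.", clf)
# analyze_text returns a 3-tuple matching the three Gradio output components below.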
# Initialize the classifier globally
classifier = TextClassifier()

# Create Gradio interface
demo = gr.Interface(
    fn=lambda text: analyze_text(text, classifier),
    inputs=gr.Textbox(
        lines=8,
        placeholder="Enter text to analyze...",
        label="Input Text"
    ),
    outputs=[
        gr.HTML(label="Highlighted Analysis"),
        gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
        gr.Textbox(label="Overall Result", lines=4)
    ],
    title="AI Text Detector",
    description="Analyze text to detect if it was written by a human or AI. Text is analyzed sentence by sentence, with color coding indicating the prediction confidence.",
    examples=[
        ["This is a sample text written by a human. It contains multiple sentences with different ideas. The analysis will show how each sentence is classified. This demonstrates the AI detection capabilities."],
    ],
    allow_flagging="never"
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()
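
# Launch notes (assumptions, not from the original commit): demo.launch()
# serves locally on http://127.0.0.1:7860 by default; when deployed as a
# Hugging Face Space the server host/port are managed by the platform, and
# allow_flagging="never" simply hides the Flag button in the UI.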