ApsidalSolid4 commited on
Commit
7eaaff0
·
verified ·
1 Parent(s): 2d71b85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch.nn.functional as F
5
+ import spacy
6
+ from typing import List, Dict
7
+ import logging
8
+ import os
9
+ import gradio as gr
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Constants
16
+ MAX_LENGTH = 512
17
+ MODEL_NAME = "microsoft/deberta-v3-small"
18
+ WINDOW_SIZE = 17
19
+ WINDOW_OVERLAP = 2
20
+ CONFIDENCE_THRESHOLD = 0.65
21
+
22
+ class TextWindowProcessor:
23
+ def __init__(self):
24
+ try:
25
+ self.nlp = spacy.load("en_core_web_sm")
26
+ except OSError:
27
+ logger.info("Downloading spacy model...")
28
+ os.system("python -m spacy download en_core_web_sm")
29
+ self.nlp = spacy.load("en_core_web_sm")
30
+
31
+ if 'sentencizer' not in self.nlp.pipe_names:
32
+ self.nlp.add_pipe('sentencizer')
33
+
34
+ disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
35
+ self.nlp.disable_pipes(*disabled_pipes)
36
+
37
+ def split_into_sentences(self, text: str) -> List[str]:
38
+ doc = self.nlp(text)
39
+ return [str(sent).strip() for sent in doc.sents]
40
+
41
+ def create_centered_windows(self, sentences: List[str], window_size: int) -> tuple[List[str], List[List[int]]]:
42
+ """Create windows centered around each sentence for detailed analysis."""
43
+ windows = []
44
+ window_sentence_indices = []
45
+
46
+ for i in range(len(sentences)):
47
+ half_window = window_size // 2
48
+ start_idx = max(0, i - half_window)
49
+ end_idx = min(len(sentences), i + half_window + 1)
50
+
51
+ if start_idx == 0:
52
+ end_idx = min(len(sentences), window_size)
53
+ elif end_idx == len(sentences):
54
+ start_idx = max(0, len(sentences) - window_size)
55
+
56
+ window = sentences[start_idx:end_idx]
57
+ windows.append(" ".join(window))
58
+ window_sentence_indices.append(list(range(start_idx, end_idx)))
59
+
60
+ return windows, window_sentence_indices
61
+
62
+ class TextClassifier:
63
+ def __init__(self):
64
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ self.model_name = MODEL_NAME
66
+ self.tokenizer = None
67
+ self.model = None
68
+ self.processor = TextWindowProcessor()
69
+ self.initialize_model()
70
+
71
+ def initialize_model(self):
72
+ """Initialize the model and tokenizer."""
73
+ logger.info("Initializing model and tokenizer...")
74
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
75
+
76
+ # First initialize the base model
77
+ self.model = AutoModelForSequenceClassification.from_pretrained(
78
+ self.model_name,
79
+ num_labels=2
80
+ ).to(self.device)
81
+
82
+ # Load your custom trained weights
83
+ model_path = "/home/user/.cache/model_files/model.pt" # Adjust filename as needed
84
+ if os.path.exists(model_path):
85
+ logger.info(f"Loading custom model from {model_path}")
86
+ checkpoint = torch.load(model_path, map_location=self.device)
87
+ self.model.load_state_dict(checkpoint['model_state_dict'])
88
+ else:
89
+ logger.warning("Custom model file not found. Using base model.")
90
+
91
+ self.model.eval()
92
+
93
+ def predict_with_sentence_scores(self, text: str) -> Dict:
94
+ """Predict with sentence-level granularity using overlapping windows."""
95
+ if not text.strip():
96
+ return {
97
+ 'sentence_predictions': [],
98
+ 'highlighted_text': '',
99
+ 'full_text': '',
100
+ 'overall_prediction': {
101
+ 'prediction': 'unknown',
102
+ 'confidence': 0.0,
103
+ 'num_sentences': 0
104
+ }
105
+ }
106
+
107
+ sentences = self.processor.split_into_sentences(text)
108
+ if not sentences:
109
+ return {}
110
+
111
+ # Create centered windows for each sentence
112
+ windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
113
+
114
+ # Track scores for each sentence
115
+ sentence_appearances = {i: 0 for i in range(len(sentences))}
116
+ sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}
117
+
118
+ # Process windows in batches to save memory
119
+ batch_size = 16
120
+ for i in range(0, len(windows), batch_size):
121
+ batch_windows = windows[i:i + batch_size]
122
+ batch_indices = window_sentence_indices[i:i + batch_size]
123
+
124
+ inputs = self.tokenizer(
125
+ batch_windows,
126
+ truncation=True,
127
+ padding=True,
128
+ max_length=MAX_LENGTH,
129
+ return_tensors="pt"
130
+ ).to(self.device)
131
+
132
+ with torch.no_grad():
133
+ outputs = self.model(**inputs)
134
+ probs = F.softmax(outputs.logits, dim=-1)
135
+
136
+ for window_idx, indices in enumerate(batch_indices):
137
+ for sent_idx in indices:
138
+ sentence_appearances[sent_idx] += 1
139
+ sentence_scores[sent_idx]['human_prob'] += probs[window_idx][1].item()
140
+ sentence_scores[sent_idx]['ai_prob'] += probs[window_idx][0].item()
141
+
142
+ # Average the scores and create final sentence-level predictions
143
+ sentence_predictions = []
144
+ for i in range(len(sentences)):
145
+ if sentence_appearances[i] > 0:
146
+ human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
147
+ ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
148
+ sentence_predictions.append({
149
+ 'sentence': sentences[i],
150
+ 'human_prob': human_prob,
151
+ 'ai_prob': ai_prob,
152
+ 'prediction': 'human' if human_prob > ai_prob else 'ai',
153
+ 'confidence': max(human_prob, ai_prob)
154
+ })
155
+
156
+ # Generate analysis outputs
157
+ return {
158
+ 'sentence_predictions': sentence_predictions,
159
+ 'highlighted_text': self.format_predictions_html(sentence_predictions),
160
+ 'full_text': text,
161
+ 'overall_prediction': self.aggregate_predictions(sentence_predictions)
162
+ }
163
+
164
+ def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
165
+ """Format predictions as HTML with color-coding."""
166
+ html_parts = []
167
+
168
+ for pred in sentence_predictions:
169
+ sentence = pred['sentence']
170
+ confidence = pred['confidence']
171
+
172
+ if confidence >= CONFIDENCE_THRESHOLD:
173
+ if pred['prediction'] == 'human':
174
+ color = "#90EE90" # Light green
175
+ else:
176
+ color = "#FFB6C6" # Light red
177
+ else:
178
+ if pred['prediction'] == 'human':
179
+ color = "#E8F5E9" # Very light green
180
+ else:
181
+ color = "#FFEBEE" # Very light red
182
+
183
+ html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
184
+
185
+ return " ".join(html_parts)
186
+
187
+ def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
188
+ """Aggregate predictions from multiple sentences into a single prediction."""
189
+ if not predictions:
190
+ return {
191
+ 'prediction': 'unknown',
192
+ 'confidence': 0.0,
193
+ 'num_sentences': 0
194
+ }
195
+
196
+ total_human_prob = sum(p['human_prob'] for p in predictions)
197
+ total_ai_prob = sum(p['ai_prob'] for p in predictions)
198
+ num_sentences = len(predictions)
199
+
200
+ avg_human_prob = total_human_prob / num_sentences
201
+ avg_ai_prob = total_ai_prob / num_sentences
202
+
203
+ return {
204
+ 'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
205
+ 'confidence': max(avg_human_prob, avg_ai_prob),
206
+ 'num_sentences': num_sentences
207
+ }
208
+
209
+ def analyze_text(text: str, classifier: TextClassifier) -> tuple:
210
+ """Analyze text and return formatted results for Gradio interface."""
211
+ # Get predictions
212
+ analysis = classifier.predict_with_sentence_scores(text)
213
+
214
+ # Format sentence-by-sentence analysis
215
+ detailed_analysis = []
216
+ for pred in analysis['sentence_predictions']:
217
+ confidence = pred['confidence'] * 100
218
+ detailed_analysis.append(f"Sentence: {pred['sentence']}")
219
+ detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
220
+ detailed_analysis.append(f"Confidence: {confidence:.1f}%")
221
+ detailed_analysis.append("-" * 50)
222
+
223
+ # Format overall prediction
224
+ final_pred = analysis['overall_prediction']
225
+ overall_result = f"""
226
+ FINAL PREDICTION: {final_pred['prediction'].upper()}
227
+ Overall confidence: {final_pred['confidence']*100:.1f}%
228
+ Number of sentences analyzed: {final_pred['num_sentences']}
229
+ """
230
+
231
+ return (
232
+ analysis['highlighted_text'],
233
+ "\n".join(detailed_analysis),
234
+ overall_result
235
+ )
236
+
237
+ # Initialize the classifier globally
238
+ classifier = TextClassifier()
239
+
240
+ # Create Gradio interface
241
+ demo = gr.Interface(
242
+ fn=lambda text: analyze_text(text, classifier),
243
+ inputs=gr.Textbox(
244
+ lines=8,
245
+ placeholder="Enter text to analyze...",
246
+ label="Input Text"
247
+ ),
248
+ outputs=[
249
+ gr.HTML(label="Highlighted Analysis"),
250
+ gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10),
251
+ gr.Textbox(label="Overall Result", lines=4)
252
+ ],
253
+ title="AI Text Detector",
254
+ description="Analyze text to detect if it was written by a human or AI. Text is analyzed sentence by sentence, with color coding indicating the prediction confidence.",
255
+ examples=[
256
+ ["This is a sample text written by a human. It contains multiple sentences with different ideas. The analysis will show how each sentence is classified. This demonstrates the AI detection capabilities."],
257
+ ],
258
+ allow_flagging="never"
259
+ )
260
+
261
+ # Launch the interface
262
+ if __name__ == "__main__":
263
+ demo.launch()