ApsidalSolid4 committed on
Commit
d475cd9
·
verified ·
1 Parent(s): 967f5dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -46
app.py CHANGED
@@ -13,18 +13,18 @@ from functools import partial
13
  import time
14
  from datetime import datetime
15
 
16
- # Configure logging
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
- # Constants
21
  MAX_LENGTH = 512
22
  MODEL_NAME = "microsoft/deberta-v3-small"
23
  WINDOW_SIZE = 6
24
  WINDOW_OVERLAP = 2
25
  CONFIDENCE_THRESHOLD = 0.65
26
- BATCH_SIZE = 8 # Reduced batch size for CPU
27
- MAX_WORKERS = 4 # Number of worker threads for processing
28
 
29
  class TextWindowProcessor:
30
  def __init__(self):
@@ -41,7 +41,7 @@ class TextWindowProcessor:
41
  disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
42
  self.nlp.disable_pipes(*disabled_pipes)
43
 
44
- # Initialize thread pool for parallel processing
45
  self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
46
 
47
  def split_into_sentences(self, text: str) -> List[str]:
@@ -60,17 +60,16 @@ class TextWindowProcessor:
60
  return windows
61
 
62
  def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
63
- """Create windows with better boundary handling"""
64
  windows = []
65
  window_sentence_indices = []
66
 
67
  for i in range(len(sentences)):
68
- # Calculate window boundaries centered on current sentence
69
  half_window = window_size // 2
70
  start_idx = max(0, i - half_window)
71
  end_idx = min(len(sentences), i + half_window + 1)
72
 
73
- # Create the window
74
  window = sentences[start_idx:end_idx]
75
  windows.append(" ".join(window))
76
  window_sentence_indices.append(list(range(start_idx, end_idx)))
@@ -79,7 +78,7 @@ class TextWindowProcessor:
79
 
80
  class TextClassifier:
81
  def __init__(self):
82
- # Set thread configuration before any model loading or parallel work
83
  if not torch.cuda.is_available():
84
  torch.set_num_threads(MAX_WORKERS)
85
  torch.set_num_interop_threads(MAX_WORKERS)
@@ -119,7 +118,6 @@ class TextClassifier:
119
  self.model.eval()
120
 
121
  def quick_scan(self, text: str) -> Dict:
122
- """Perform a quick scan using simple window analysis."""
123
  if not text.strip():
124
  return {
125
  'prediction': 'unknown',
@@ -132,7 +130,7 @@ class TextClassifier:
132
 
133
  predictions = []
134
 
135
- # Process windows in smaller batches for CPU efficiency
136
  for i in range(0, len(windows), BATCH_SIZE):
137
  batch_windows = windows[i:i + BATCH_SIZE]
138
 
@@ -157,7 +155,7 @@ class TextClassifier:
157
  }
158
  predictions.append(prediction)
159
 
160
- # Clean up GPU memory if available
161
  del inputs, outputs, probs
162
  if torch.cuda.is_available():
163
  torch.cuda.empty_cache()
@@ -179,8 +177,7 @@ class TextClassifier:
179
  }
180
 
181
  def detailed_scan(self, text: str) -> Dict:
182
- """Perform a detailed scan with improved sentence-level analysis."""
183
- # Clean up trailing whitespace
184
  text = text.rstrip()
185
 
186
  if not text.strip():
@@ -199,14 +196,14 @@ class TextClassifier:
199
  if not sentences:
200
  return {}
201
 
202
- # Create centered windows for each sentence
203
  windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
204
 
205
- # Track scores for each sentence
206
  sentence_appearances = {i: 0 for i in range(len(sentences))}
207
  sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}
208
 
209
- # Process windows in batches
210
  for i in range(0, len(windows), BATCH_SIZE):
211
  batch_windows = windows[i:i + BATCH_SIZE]
212
  batch_indices = window_sentence_indices[i:i + BATCH_SIZE]
@@ -223,45 +220,45 @@ class TextClassifier:
223
  outputs = self.model(**inputs)
224
  probs = F.softmax(outputs.logits, dim=-1)
225
 
226
- # Attribute predictions with weighted scoring
227
  for window_idx, indices in enumerate(batch_indices):
228
  center_idx = len(indices) // 2
229
- center_weight = 0.7 # Higher weight for center sentence
230
- edge_weight = 0.3 / (len(indices) - 1) # Distribute remaining weight
231
 
232
  for pos, sent_idx in enumerate(indices):
233
- # Apply higher weight to center sentence
234
  weight = center_weight if pos == center_idx else edge_weight
235
  sentence_appearances[sent_idx] += weight
236
  sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
237
  sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()
238
 
239
- # Clean up memory
240
  del inputs, outputs, probs
241
  if torch.cuda.is_available():
242
  torch.cuda.empty_cache()
243
 
244
- # Calculate final predictions with boundary smoothing
245
  sentence_predictions = []
246
  for i in range(len(sentences)):
247
  if sentence_appearances[i] > 0:
248
  human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
249
  ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
250
 
251
- # Apply minimal smoothing at prediction boundaries
252
  if i > 0 and i < len(sentences) - 1:
253
  prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
254
  prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
255
  next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
256
  next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]
257
 
258
- # Check if we're at a prediction boundary
259
  current_pred = 'human' if human_prob > ai_prob else 'ai'
260
  prev_pred = 'human' if prev_human > prev_ai else 'ai'
261
  next_pred = 'human' if next_human > next_ai else 'ai'
262
 
263
  if current_pred != prev_pred or current_pred != next_pred:
264
- # Small adjustment at boundaries
265
  smooth_factor = 0.1
266
  human_prob = (human_prob * (1 - smooth_factor) +
267
  (prev_human + next_human) * smooth_factor / 2)
@@ -284,7 +281,6 @@ class TextClassifier:
284
  }
285
 
286
  def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
287
- """Format predictions as HTML with color-coding."""
288
  html_parts = []
289
 
290
  for pred in sentence_predictions:
@@ -293,21 +289,20 @@ class TextClassifier:
293
 
294
  if confidence >= CONFIDENCE_THRESHOLD:
295
  if pred['prediction'] == 'human':
296
- color = "#90EE90" # Light green
297
  else:
298
- color = "#FFB6C6" # Light red
299
  else:
300
  if pred['prediction'] == 'human':
301
- color = "#E8F5E9" # Very light green
302
  else:
303
- color = "#FFEBEE" # Very light red
304
 
305
  html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
306
 
307
  return " ".join(html_parts)
308
 
309
  def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
310
- """Aggregate predictions from multiple sentences into a single prediction."""
311
  if not predictions:
312
  return {
313
  'prediction': 'unknown',
@@ -329,14 +324,13 @@ class TextClassifier:
329
  }
330
 
331
  def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
332
- """Analyze text using specified mode and return formatted results."""
333
- # Start timing for normal analysis
334
  start_time = time.time()
335
 
336
- # Count words in the text
337
  word_count = len(text.split())
338
 
339
- # If text is less than 200 words and detailed mode is selected, switch to quick mode
340
  original_mode = mode
341
  if word_count < 200 and mode == "detailed":
342
  mode = "quick"
@@ -350,15 +344,15 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
350
  Windows analyzed: {result['num_windows']}
351
  """
352
 
353
- # Add note if mode was switched
354
  if original_mode == "detailed":
355
  quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."
356
 
357
- # Calculate execution time in milliseconds
358
  execution_time = (time.time() - start_time) * 1000
359
 
360
  return (
361
- text, # No highlighting in quick mode
362
  "Quick scan mode - no sentence-level analysis available",
363
  quick_analysis
364
  )
@@ -380,7 +374,7 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
380
  Number of sentences analyzed: {final_pred['num_sentences']}
381
  """
382
 
383
- # Calculate execution time in milliseconds
384
  execution_time = (time.time() - start_time) * 1000
385
 
386
  return (
@@ -389,10 +383,10 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
389
  overall_result
390
  )
391
 
392
- # Initialize the classifier globally
393
  classifier = TextClassifier()
394
 
395
- # Create Gradio interface
396
  demo = gr.Interface(
397
  fn=lambda text, mode: analyze_text(text, mode, classifier),
398
  inputs=[
@@ -419,19 +413,17 @@ demo = gr.Interface(
419
  flagging_mode="never"
420
  )
421
 
422
- # Get the FastAPI app from Gradio
423
  app = demo.app
424
 
425
- # Add CORS middleware
426
  app.add_middleware(
427
  CORSMiddleware,
428
- allow_origins=["*"], # For development
429
  allow_credentials=True,
430
  allow_methods=["GET", "POST", "OPTIONS"],
431
  allow_headers=["*"],
432
  )
433
 
434
- # Ensure CORS is applied before launching
435
  if __name__ == "__main__":
436
  demo.queue()
437
  demo.launch(
 
13
  import time
14
  from datetime import datetime
15
 
16
+
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
+
21
  MAX_LENGTH = 512
22
  MODEL_NAME = "microsoft/deberta-v3-small"
23
  WINDOW_SIZE = 6
24
  WINDOW_OVERLAP = 2
25
  CONFIDENCE_THRESHOLD = 0.65
26
+ BATCH_SIZE = 8
27
+ MAX_WORKERS = 4
28
 
29
  class TextWindowProcessor:
30
  def __init__(self):
 
41
  disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
42
  self.nlp.disable_pipes(*disabled_pipes)
43
 
44
+
45
  self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
46
 
47
  def split_into_sentences(self, text: str) -> List[str]:
 
60
  return windows
61
 
62
  def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
 
63
  windows = []
64
  window_sentence_indices = []
65
 
66
  for i in range(len(sentences)):
67
+
68
  half_window = window_size // 2
69
  start_idx = max(0, i - half_window)
70
  end_idx = min(len(sentences), i + half_window + 1)
71
 
72
+
73
  window = sentences[start_idx:end_idx]
74
  windows.append(" ".join(window))
75
  window_sentence_indices.append(list(range(start_idx, end_idx)))
 
78
 
79
  class TextClassifier:
80
  def __init__(self):
81
+
82
  if not torch.cuda.is_available():
83
  torch.set_num_threads(MAX_WORKERS)
84
  torch.set_num_interop_threads(MAX_WORKERS)
 
118
  self.model.eval()
119
 
120
  def quick_scan(self, text: str) -> Dict:
 
121
  if not text.strip():
122
  return {
123
  'prediction': 'unknown',
 
130
 
131
  predictions = []
132
 
133
+
134
  for i in range(0, len(windows), BATCH_SIZE):
135
  batch_windows = windows[i:i + BATCH_SIZE]
136
 
 
155
  }
156
  predictions.append(prediction)
157
 
158
+
159
  del inputs, outputs, probs
160
  if torch.cuda.is_available():
161
  torch.cuda.empty_cache()
 
177
  }
178
 
179
  def detailed_scan(self, text: str) -> Dict:
180
+
 
181
  text = text.rstrip()
182
 
183
  if not text.strip():
 
196
  if not sentences:
197
  return {}
198
 
199
+
200
  windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
201
 
202
+
203
  sentence_appearances = {i: 0 for i in range(len(sentences))}
204
  sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}
205
 
206
+
207
  for i in range(0, len(windows), BATCH_SIZE):
208
  batch_windows = windows[i:i + BATCH_SIZE]
209
  batch_indices = window_sentence_indices[i:i + BATCH_SIZE]
 
220
  outputs = self.model(**inputs)
221
  probs = F.softmax(outputs.logits, dim=-1)
222
 
223
+
224
  for window_idx, indices in enumerate(batch_indices):
225
  center_idx = len(indices) // 2
226
+ center_weight = 0.7
227
+ edge_weight = 0.3 / (len(indices) - 1)
228
 
229
  for pos, sent_idx in enumerate(indices):
230
+
231
  weight = center_weight if pos == center_idx else edge_weight
232
  sentence_appearances[sent_idx] += weight
233
  sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
234
  sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()
235
 
236
+
237
  del inputs, outputs, probs
238
  if torch.cuda.is_available():
239
  torch.cuda.empty_cache()
240
 
241
+
242
  sentence_predictions = []
243
  for i in range(len(sentences)):
244
  if sentence_appearances[i] > 0:
245
  human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
246
  ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
247
 
248
+
249
  if i > 0 and i < len(sentences) - 1:
250
  prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
251
  prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
252
  next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
253
  next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]
254
 
255
+
256
  current_pred = 'human' if human_prob > ai_prob else 'ai'
257
  prev_pred = 'human' if prev_human > prev_ai else 'ai'
258
  next_pred = 'human' if next_human > next_ai else 'ai'
259
 
260
  if current_pred != prev_pred or current_pred != next_pred:
261
+
262
  smooth_factor = 0.1
263
  human_prob = (human_prob * (1 - smooth_factor) +
264
  (prev_human + next_human) * smooth_factor / 2)
 
281
  }
282
 
283
  def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
 
284
  html_parts = []
285
 
286
  for pred in sentence_predictions:
 
289
 
290
  if confidence >= CONFIDENCE_THRESHOLD:
291
  if pred['prediction'] == 'human':
292
+ color = "#90EE90"
293
  else:
294
+ color = "#FFB6C6"
295
  else:
296
  if pred['prediction'] == 'human':
297
+ color = "#E8F5E9"
298
  else:
299
+ color = "#FFEBEE"
300
 
301
  html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')
302
 
303
  return " ".join(html_parts)
304
 
305
  def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
 
306
  if not predictions:
307
  return {
308
  'prediction': 'unknown',
 
324
  }
325
 
326
  def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
327
+
 
328
  start_time = time.time()
329
 
330
+
331
  word_count = len(text.split())
332
 
333
+
334
  original_mode = mode
335
  if word_count < 200 and mode == "detailed":
336
  mode = "quick"
 
344
  Windows analyzed: {result['num_windows']}
345
  """
346
 
347
+
348
  if original_mode == "detailed":
349
  quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."
350
 
351
+
352
  execution_time = (time.time() - start_time) * 1000
353
 
354
  return (
355
+ text,
356
  "Quick scan mode - no sentence-level analysis available",
357
  quick_analysis
358
  )
 
374
  Number of sentences analyzed: {final_pred['num_sentences']}
375
  """
376
 
377
+
378
  execution_time = (time.time() - start_time) * 1000
379
 
380
  return (
 
383
  overall_result
384
  )
385
 
386
+
387
  classifier = TextClassifier()
388
 
389
+
390
  demo = gr.Interface(
391
  fn=lambda text, mode: analyze_text(text, mode, classifier),
392
  inputs=[
 
413
  flagging_mode="never"
414
  )
415
 
416
+
417
  app = demo.app
418
 
 
419
  app.add_middleware(
420
  CORSMiddleware,
421
+ allow_origins=["*"],
422
  allow_credentials=True,
423
  allow_methods=["GET", "POST", "OPTIONS"],
424
  allow_headers=["*"],
425
  )
426
 
 
427
  if __name__ == "__main__":
428
  demo.queue()
429
  demo.launch(