Update app.py
Browse files
app.py
CHANGED
@@ -12,8 +12,8 @@ from concurrent.futures import ThreadPoolExecutor
|
|
12 |
from functools import partial
|
13 |
import time
|
14 |
import csv
|
15 |
-
import os
|
16 |
from datetime import datetime
|
|
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
@@ -27,37 +27,6 @@ CONFIDENCE_THRESHOLD = 0.65
|
|
27 |
BATCH_SIZE = 8 # Reduced batch size for CPU
|
28 |
MAX_WORKERS = 4 # Number of worker threads for processing
|
29 |
|
30 |
-
|
31 |
-
def log_prediction_data(input_text, word_count, prediction, confidence, execution_time, mode):
|
32 |
-
"""Log prediction data to a CSV file in the /tmp directory."""
|
33 |
-
# Define the CSV file path
|
34 |
-
csv_path = "/tmp/prediction_logs.csv"
|
35 |
-
|
36 |
-
# Check if file exists to determine if we need to write headers
|
37 |
-
file_exists = os.path.isfile(csv_path)
|
38 |
-
|
39 |
-
try:
|
40 |
-
with open(csv_path, 'a', newline='', encoding='utf-8') as f:
|
41 |
-
writer = csv.writer(f)
|
42 |
-
|
43 |
-
# Write headers if the file is newly created
|
44 |
-
if not file_exists:
|
45 |
-
writer.writerow(["timestamp", "word_count", "prediction", "confidence", "execution_time_ms", "analysis_mode", "full_text"])
|
46 |
-
|
47 |
-
# Clean up the input text for CSV storage (replace newlines with spaces)
|
48 |
-
cleaned_text = input_text.replace("\n", " ")
|
49 |
-
|
50 |
-
# Write the data row with the full text
|
51 |
-
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
52 |
-
writer.writerow([timestamp, word_count, prediction, f"{confidence:.2f}", f"{execution_time:.2f}", mode, cleaned_text])
|
53 |
-
|
54 |
-
logger.info(f"Successfully logged prediction data to {csv_path}")
|
55 |
-
return True
|
56 |
-
except Exception as e:
|
57 |
-
logger.error(f"Error logging prediction data: {str(e)}")
|
58 |
-
return False
|
59 |
-
|
60 |
-
|
61 |
class TextWindowProcessor:
|
62 |
def __init__(self):
|
63 |
try:
|
@@ -210,100 +179,6 @@ class TextClassifier:
|
|
210 |
'num_windows': len(predictions)
|
211 |
}
|
212 |
|
213 |
-
# def detailed_scan(self, text: str) -> Dict:
|
214 |
-
# """Original prediction method with modified window handling"""
|
215 |
-
# if self.model is None or self.tokenizer is None:
|
216 |
-
# self.load_model()
|
217 |
-
|
218 |
-
# self.model.eval()
|
219 |
-
# sentences = self.processor.split_into_sentences(text)
|
220 |
-
# if not sentences:
|
221 |
-
# return {}
|
222 |
-
|
223 |
-
# # Create centered windows for each sentence
|
224 |
-
# windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)
|
225 |
-
|
226 |
-
# # Track scores for each sentence
|
227 |
-
# sentence_appearances = {i: 0 for i in range(len(sentences))}
|
228 |
-
# sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}
|
229 |
-
|
230 |
-
# # Process windows in batches
|
231 |
-
# batch_size = 16
|
232 |
-
# for i in range(0, len(windows), batch_size):
|
233 |
-
# batch_windows = windows[i:i + batch_size]
|
234 |
-
# batch_indices = window_sentence_indices[i:i + batch_size]
|
235 |
-
|
236 |
-
# inputs = self.tokenizer(
|
237 |
-
# batch_windows,
|
238 |
-
# truncation=True,
|
239 |
-
# padding=True,
|
240 |
-
# max_length=MAX_LENGTH,
|
241 |
-
# return_tensors="pt"
|
242 |
-
# ).to(self.device)
|
243 |
-
|
244 |
-
# with torch.no_grad():
|
245 |
-
# outputs = self.model(**inputs)
|
246 |
-
# probs = F.softmax(outputs.logits, dim=-1)
|
247 |
-
|
248 |
-
# # Attribute predictions more carefully
|
249 |
-
# for window_idx, indices in enumerate(batch_indices):
|
250 |
-
# center_idx = len(indices) // 2
|
251 |
-
# center_weight = 0.7 # Higher weight for center sentence
|
252 |
-
# edge_weight = 0.3 / (len(indices) - 1) # Distribute remaining weight
|
253 |
-
|
254 |
-
# for pos, sent_idx in enumerate(indices):
|
255 |
-
# # Apply higher weight to center sentence
|
256 |
-
# weight = center_weight if pos == center_idx else edge_weight
|
257 |
-
# sentence_appearances[sent_idx] += weight
|
258 |
-
# sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
|
259 |
-
# sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()
|
260 |
-
|
261 |
-
# del inputs, outputs, probs
|
262 |
-
# if torch.cuda.is_available():
|
263 |
-
# torch.cuda.empty_cache()
|
264 |
-
|
265 |
-
# # Calculate final predictions
|
266 |
-
# sentence_predictions = []
|
267 |
-
# for i in range(len(sentences)):
|
268 |
-
# if sentence_appearances[i] > 0:
|
269 |
-
# human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
|
270 |
-
# ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]
|
271 |
-
|
272 |
-
# # Only apply minimal smoothing at prediction boundaries
|
273 |
-
# if i > 0 and i < len(sentences) - 1:
|
274 |
-
# prev_human = sentence_scores[i-1]['human_prob'] / sentence_appearances[i-1]
|
275 |
-
# prev_ai = sentence_scores[i-1]['ai_prob'] / sentence_appearances[i-1]
|
276 |
-
# next_human = sentence_scores[i+1]['human_prob'] / sentence_appearances[i+1]
|
277 |
-
# next_ai = sentence_scores[i+1]['ai_prob'] / sentence_appearances[i+1]
|
278 |
-
|
279 |
-
# # Check if we're at a prediction boundary
|
280 |
-
# current_pred = 'human' if human_prob > ai_prob else 'ai'
|
281 |
-
# prev_pred = 'human' if prev_human > prev_ai else 'ai'
|
282 |
-
# next_pred = 'human' if next_human > next_ai else 'ai'
|
283 |
-
|
284 |
-
# if current_pred != prev_pred or current_pred != next_pred:
|
285 |
-
# # Small adjustment at boundaries
|
286 |
-
# smooth_factor = 0.1
|
287 |
-
# human_prob = (human_prob * (1 - smooth_factor) +
|
288 |
-
# (prev_human + next_human) * smooth_factor / 2)
|
289 |
-
# ai_prob = (ai_prob * (1 - smooth_factor) +
|
290 |
-
# (prev_ai + next_ai) * smooth_factor / 2)
|
291 |
-
|
292 |
-
# sentence_predictions.append({
|
293 |
-
# 'sentence': sentences[i],
|
294 |
-
# 'human_prob': human_prob,
|
295 |
-
# 'ai_prob': ai_prob,
|
296 |
-
# 'prediction': 'human' if human_prob > ai_prob else 'ai',
|
297 |
-
# 'confidence': max(human_prob, ai_prob)
|
298 |
-
# })
|
299 |
-
|
300 |
-
# return {
|
301 |
-
# 'sentence_predictions': sentence_predictions,
|
302 |
-
# 'highlighted_text': self.format_predictions_html(sentence_predictions),
|
303 |
-
# 'full_text': text,
|
304 |
-
# 'overall_prediction': self.aggregate_predictions(sentence_predictions)
|
305 |
-
# }
|
306 |
-
|
307 |
def detailed_scan(self, text: str) -> Dict:
|
308 |
"""Perform a detailed scan with improved sentence-level analysis."""
|
309 |
# Clean up trailing whitespace
|
@@ -454,88 +329,124 @@ class TextClassifier:
|
|
454 |
'num_sentences': num_sentences
|
455 |
}
|
456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
|
540 |
# Initialize the classifier globally
|
541 |
classifier = TextClassifier()
|
@@ -567,8 +478,17 @@ demo = gr.Interface(
|
|
567 |
flagging_mode="never"
|
568 |
)
|
569 |
|
570 |
-
|
571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
CORSMiddleware,
|
573 |
allow_origins=["*"], # For development
|
574 |
allow_credentials=True,
|
|
|
12 |
from functools import partial
|
13 |
import time
|
14 |
import csv
|
|
|
15 |
from datetime import datetime
|
16 |
+
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
|
|
27 |
BATCH_SIZE = 8 # Reduced batch size for CPU
|
28 |
MAX_WORKERS = 4 # Number of worker threads for processing
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class TextWindowProcessor:
|
31 |
def __init__(self):
|
32 |
try:
|
|
|
179 |
'num_windows': len(predictions)
|
180 |
}
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
def detailed_scan(self, text: str) -> Dict:
|
183 |
"""Perform a detailed scan with improved sentence-level analysis."""
|
184 |
# Clean up trailing whitespace
|
|
|
329 |
'num_sentences': num_sentences
|
330 |
}
|
331 |
|
332 |
+
def log_prediction_data(input_text, word_count, prediction, confidence, execution_time, mode):
|
333 |
+
"""Log prediction data to a CSV file in the /tmp directory."""
|
334 |
+
# Define the CSV file path
|
335 |
+
csv_path = "/tmp/prediction_logs.csv"
|
336 |
+
|
337 |
+
# Check if file exists to determine if we need to write headers
|
338 |
+
file_exists = os.path.isfile(csv_path)
|
339 |
+
|
340 |
+
try:
|
341 |
+
with open(csv_path, 'a', newline='', encoding='utf-8') as f:
|
342 |
+
writer = csv.writer(f)
|
343 |
+
|
344 |
+
# Write headers if the file is newly created
|
345 |
+
if not file_exists:
|
346 |
+
writer.writerow(["timestamp", "word_count", "prediction", "confidence", "execution_time_ms", "analysis_mode", "full_text"])
|
347 |
+
|
348 |
+
# Clean up the input text for CSV storage (replace newlines with spaces)
|
349 |
+
cleaned_text = input_text.replace("\n", " ")
|
350 |
+
|
351 |
+
# Write the data row with the full text
|
352 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
353 |
+
writer.writerow([timestamp, word_count, prediction, f"{confidence:.2f}", f"{execution_time:.2f}", mode, cleaned_text])
|
354 |
+
|
355 |
+
logger.info(f"Successfully logged prediction data to {csv_path}")
|
356 |
+
return True
|
357 |
+
except Exception as e:
|
358 |
+
logger.error(f"Error logging prediction data: {str(e)}")
|
359 |
+
return False
|
360 |
|
361 |
+
def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
|
362 |
+
"""Analyze text using specified mode and return formatted results."""
|
363 |
+
# Start timing
|
364 |
+
start_time = time.time()
|
365 |
+
|
366 |
+
# Count words in the text
|
367 |
+
word_count = len(text.split())
|
368 |
+
|
369 |
+
# If text is less than 200 words and detailed mode is selected, switch to quick mode
|
370 |
+
original_mode = mode
|
371 |
+
if word_count < 200 and mode == "detailed":
|
372 |
+
mode = "quick"
|
373 |
+
|
374 |
+
if mode == "quick":
|
375 |
+
result = classifier.quick_scan(text)
|
376 |
+
|
377 |
+
quick_analysis = f"""
|
378 |
+
PREDICTION: {result['prediction'].upper()}
|
379 |
+
Confidence: {result['confidence']*100:.1f}%
|
380 |
+
Windows analyzed: {result['num_windows']}
|
381 |
+
"""
|
382 |
+
|
383 |
+
# Add note if mode was switched
|
384 |
+
if original_mode == "detailed":
|
385 |
+
quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."
|
386 |
+
|
387 |
+
# Calculate execution time in milliseconds
|
388 |
+
execution_time = (time.time() - start_time) * 1000
|
389 |
+
|
390 |
+
# Log the prediction data
|
391 |
+
log_prediction_data(
|
392 |
+
input_text=text,
|
393 |
+
word_count=word_count,
|
394 |
+
prediction=result['prediction'],
|
395 |
+
confidence=result['confidence'],
|
396 |
+
execution_time=execution_time,
|
397 |
+
mode=original_mode
|
398 |
+
)
|
399 |
+
|
400 |
+
return (
|
401 |
+
text, # No highlighting in quick mode
|
402 |
+
"Quick scan mode - no sentence-level analysis available",
|
403 |
+
quick_analysis
|
404 |
+
)
|
405 |
+
else:
|
406 |
+
analysis = classifier.detailed_scan(text)
|
407 |
+
|
408 |
+
detailed_analysis = []
|
409 |
+
for pred in analysis['sentence_predictions']:
|
410 |
+
confidence = pred['confidence'] * 100
|
411 |
+
detailed_analysis.append(f"Sentence: {pred['sentence']}")
|
412 |
+
detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
|
413 |
+
detailed_analysis.append(f"Confidence: {confidence:.1f}%")
|
414 |
+
detailed_analysis.append("-" * 50)
|
415 |
+
|
416 |
+
final_pred = analysis['overall_prediction']
|
417 |
+
overall_result = f"""
|
418 |
+
FINAL PREDICTION: {final_pred['prediction'].upper()}
|
419 |
+
Overall confidence: {final_pred['confidence']*100:.1f}%
|
420 |
+
Number of sentences analyzed: {final_pred['num_sentences']}
|
421 |
+
"""
|
422 |
+
|
423 |
+
# Calculate execution time in milliseconds
|
424 |
+
execution_time = (time.time() - start_time) * 1000
|
425 |
+
|
426 |
+
# Log the prediction data
|
427 |
+
log_prediction_data(
|
428 |
+
input_text=text,
|
429 |
+
word_count=word_count,
|
430 |
+
prediction=final_pred['prediction'],
|
431 |
+
confidence=final_pred['confidence'],
|
432 |
+
execution_time=execution_time,
|
433 |
+
mode=original_mode
|
434 |
+
)
|
435 |
+
|
436 |
+
return (
|
437 |
+
analysis['highlighted_text'],
|
438 |
+
"\n".join(detailed_analysis),
|
439 |
+
overall_result
|
440 |
+
)
|
441 |
|
442 |
+
# Add a function to download the logs
|
443 |
+
def download_logs():
|
444 |
+
log_path = "/tmp/prediction_logs.csv"
|
445 |
+
if os.path.exists(log_path):
|
446 |
+
with open(log_path, 'r', encoding='utf-8') as f:
|
447 |
+
content = f.read()
|
448 |
+
return content
|
449 |
+
return "No logs found."
|
450 |
|
451 |
# Initialize the classifier globally
|
452 |
classifier = TextClassifier()
|
|
|
478 |
flagging_mode="never"
|
479 |
)
|
480 |
|
481 |
+
# Add admin panel for log access (only visible to space owners)
|
482 |
+
with gr.Blocks() as admin_interface:
|
483 |
+
gr.Markdown("## Admin Panel - Data Logs")
|
484 |
+
download_button = gr.Button("Download Logs")
|
485 |
+
log_output = gr.File(label="Prediction Logs")
|
486 |
+
download_button.click(fn=download_logs, outputs=log_output)
|
487 |
+
|
488 |
+
# Combine interfaces
|
489 |
+
app = gr.TabbedInterface([demo, admin_interface], ["AI Text Detector", "Admin"])
|
490 |
+
|
491 |
+
app.app.add_middleware(
|
492 |
CORSMiddleware,
|
493 |
allow_origins=["*"], # For development
|
494 |
allow_credentials=True,
|