ApsidalSolid4 commited on
Commit
7e5917e
·
verified ·
1 Parent(s): 41365d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -20
app.py CHANGED
@@ -13,6 +13,8 @@ from functools import partial
13
  import time
14
  import csv
15
  from datetime import datetime
 
 
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
@@ -439,15 +441,6 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
439
  overall_result
440
  )
441
 
442
- # Add a function to download the logs
443
- def download_logs():
444
- log_path = "/tmp/prediction_logs.csv"
445
- if os.path.exists(log_path):
446
- with open(log_path, 'r', encoding='utf-8') as f:
447
- content = f.read()
448
- return content
449
- return "No logs found."
450
-
451
  # Initialize the classifier globally
452
  classifier = TextClassifier()
453
 
@@ -478,24 +471,31 @@ demo = gr.Interface(
478
  flagging_mode="never"
479
  )
480
 
481
- # Add admin panel for log access (only visible to space owners)
482
- with gr.Blocks() as admin_interface:
483
- gr.Markdown("## Admin Panel - Data Logs")
484
- download_button = gr.Button("Download Logs")
485
- log_output = gr.File(label="Prediction Logs")
486
- download_button.click(fn=download_logs, outputs=log_output)
487
-
488
- # Combine interfaces
489
- app = gr.TabbedInterface([demo, admin_interface], ["AI Text Detector", "Admin"])
490
 
491
- app.app.add_middleware(
 
492
  CORSMiddleware,
493
  allow_origins=["*"], # For development
494
  allow_credentials=True,
495
- allow_methods=["GET", "POST", "OPTIONS"], # Explicitly list methods
496
  allow_headers=["*"],
497
  )
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  # Ensure CORS is applied before launching
500
  if __name__ == "__main__":
501
  demo.queue()
 
13
  import time
14
  import csv
15
  from datetime import datetime
16
+ from fastapi import FastAPI
17
+ from starlette.responses import FileResponse
18
 
19
  # Configure logging
20
  logging.basicConfig(level=logging.INFO)
 
441
  overall_result
442
  )
443
 
 
 
 
 
 
 
 
 
 
444
  # Initialize the classifier globally
445
  classifier = TextClassifier()
446
 
 
471
  flagging_mode="never"
472
  )
473
 
474
+ # Get the FastAPI app from Gradio
475
+ app = demo.app
 
 
 
 
 
 
 
476
 
477
+ # Add CORS middleware
478
+ app.add_middleware(
479
  CORSMiddleware,
480
  allow_origins=["*"], # For development
481
  allow_credentials=True,
482
+ allow_methods=["GET", "POST", "OPTIONS"],
483
  allow_headers=["*"],
484
  )
485
 
486
+ # Add FastAPI endpoint for downloading logs
487
+ @app.get("/download-logs")
488
+ async def download_logs():
489
+ """Endpoint to download the prediction logs CSV file."""
490
+ log_path = "/tmp/prediction_logs.csv"
491
+ if os.path.exists(log_path):
492
+ return FileResponse(
493
+ path=log_path,
494
+ filename="prediction_logs.csv",
495
+ media_type="text/csv"
496
+ )
497
+ return {"error": "No logs found"}
498
+
499
  # Ensure CORS is applied before launching
500
  if __name__ == "__main__":
501
  demo.queue()