Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,8 @@ from functools import partial
|
|
13 |
import time
|
14 |
import csv
|
15 |
from datetime import datetime
|
|
|
|
|
16 |
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
@@ -439,15 +441,6 @@ def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
|
|
439 |
overall_result
|
440 |
)
|
441 |
|
442 |
-
# Add a function to download the logs
|
443 |
-
def download_logs():
|
444 |
-
log_path = "/tmp/prediction_logs.csv"
|
445 |
-
if os.path.exists(log_path):
|
446 |
-
with open(log_path, 'r', encoding='utf-8') as f:
|
447 |
-
content = f.read()
|
448 |
-
return content
|
449 |
-
return "No logs found."
|
450 |
-
|
451 |
# Initialize the classifier globally
|
452 |
classifier = TextClassifier()
|
453 |
|
@@ -478,24 +471,31 @@ demo = gr.Interface(
|
|
478 |
flagging_mode="never"
|
479 |
)
|
480 |
|
481 |
-
#
|
482 |
-
|
483 |
-
gr.Markdown("## Admin Panel - Data Logs")
|
484 |
-
download_button = gr.Button("Download Logs")
|
485 |
-
log_output = gr.File(label="Prediction Logs")
|
486 |
-
download_button.click(fn=download_logs, outputs=log_output)
|
487 |
-
|
488 |
-
# Combine interfaces
|
489 |
-
app = gr.TabbedInterface([demo, admin_interface], ["AI Text Detector", "Admin"])
|
490 |
|
491 |
-
|
|
|
492 |
CORSMiddleware,
|
493 |
allow_origins=["*"], # For development
|
494 |
allow_credentials=True,
|
495 |
-
allow_methods=["GET", "POST", "OPTIONS"],
|
496 |
allow_headers=["*"],
|
497 |
)
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
# Ensure CORS is applied before launching
|
500 |
if __name__ == "__main__":
|
501 |
demo.queue()
|
|
|
13 |
import time
|
14 |
import csv
|
15 |
from datetime import datetime
|
16 |
+
from fastapi import FastAPI
|
17 |
+
from starlette.responses import FileResponse
|
18 |
|
19 |
# Configure logging
|
20 |
logging.basicConfig(level=logging.INFO)
|
|
|
441 |
overall_result
|
442 |
)
|
443 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
# Initialize the classifier globally
|
445 |
classifier = TextClassifier()
|
446 |
|
|
|
471 |
flagging_mode="never"
|
472 |
)
|
473 |
|
474 |
+
# Get the FastAPI app from Gradio
|
475 |
+
app = demo.app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
477 |
+
# Add CORS middleware
|
478 |
+
app.add_middleware(
|
479 |
CORSMiddleware,
|
480 |
allow_origins=["*"], # For development
|
481 |
allow_credentials=True,
|
482 |
+
allow_methods=["GET", "POST", "OPTIONS"],
|
483 |
allow_headers=["*"],
|
484 |
)
|
485 |
|
486 |
+
# Add FastAPI endpoint for downloading logs
|
487 |
+
@app.get("/download-logs")
|
488 |
+
async def download_logs():
|
489 |
+
"""Endpoint to download the prediction logs CSV file."""
|
490 |
+
log_path = "/tmp/prediction_logs.csv"
|
491 |
+
if os.path.exists(log_path):
|
492 |
+
return FileResponse(
|
493 |
+
path=log_path,
|
494 |
+
filename="prediction_logs.csv",
|
495 |
+
media_type="text/csv"
|
496 |
+
)
|
497 |
+
return {"error": "No logs found"}
|
498 |
+
|
499 |
# Ensure CORS is applied before launching
|
500 |
if __name__ == "__main__":
|
501 |
demo.queue()
|