|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
from fastapi.staticfiles import StaticFiles |
|
import numpy as np |
|
import argparse |
|
import os |
|
from datasets import load_dataset |
|
|
|
# Server configuration. Environment variables supply the defaults and the
# command-line flags below override them.
# NOTE(review): the bind host is read from an env var named API_URL; confirm
# that name is intentional for the deployment environment.
HOST = os.environ.get("API_URL", "0.0.0.0")
# Normalize to int so PORT has one type whether or not the env var is set.
PORT = int(os.environ.get("PORT", 7860))

parser = argparse.ArgumentParser()
parser.add_argument("--host", default=HOST)
parser.add_argument("--port", type=int, default=PORT)
# Auto-reload is on by default. A lone store_true flag with default=True could
# never switch it off, so --no-reload writes False to the same destination.
# --reload keeps working for existing launch commands.
parser.add_argument("--reload", dest="reload", action="store_true", default=True)
parser.add_argument("--no-reload", dest="reload", action="store_false")
parser.add_argument("--ssl_certfile")
parser.add_argument("--ssl_keyfile")
# This module is also imported by other programs (e.g. `uvicorn app:app ...`),
# whose sys.argv contains arguments this parser does not define. Ignore those
# instead of exiting at import time.
args, _unknown_args = parser.parse_known_args()
|
|
|
app = FastAPI()

# Fully permissive CORS policy: every origin, method and header is accepted,
# and credentialed requests are allowed.
# NOTE(review): browsers do not accept a literal "*" origin on credentialed
# requests. Check how the installed Starlette CORSMiddleware handles
# allow_origins=["*"] together with allow_credentials=True before adding
# cookie/auth-based features to this API.
_cors_settings = dict(
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
# Columns the aggregation below reads from the results dataset.
_REQUIRED_COLUMNS = ['model_id', 'agent_action_type', 'source', 'acc']
# Benchmarks reported to the frontend, in display order.
_BENCHMARKS = ['GAIA', 'MATH', 'SimpleQA']


def _summarize_results(df):
    """Aggregate per-row results into one score record per (model, agent type).

    Args:
        df: pandas DataFrame containing every column in ``_REQUIRED_COLUMNS``.
            ``acc`` is averaged per benchmark and multiplied by 100. Presumably
            it holds accuracies in [0, 1]; confirm against the dataset.

    Returns:
        A list of dicts with keys ``'model_id'``, ``'agent_action_type'`` and
        ``'scores'``. ``scores`` maps each benchmark found for that pair to its
        mean ``acc`` * 100. When at least one benchmark is present, it also
        holds an ``'Average'`` entry: the unweighted mean of those scores.
    """
    result = []
    for (model_id, agent_action_type), group in df.groupby(['model_id', 'agent_action_type']):
        benchmark_scores = {}
        for benchmark in _BENCHMARKS:
            benchmark_group = group[group['source'] == benchmark]
            if benchmark_group.empty:
                continue
            score = benchmark_group['acc'].mean() * 100
            # If every acc value is NaN, the mean is NaN. Starlette's
            # JSONResponse serializes with allow_nan=False, and a NaN would
            # also poison 'Average', so such a benchmark is left out.
            if np.isnan(score):
                continue
            benchmark_scores[benchmark] = score

        if benchmark_scores:
            benchmark_scores['Average'] = sum(benchmark_scores.values()) / len(benchmark_scores)

        result.append({
            'model_id': model_id,
            'agent_action_type': agent_action_type,
            'scores': benchmark_scores,
        })
    return result


@app.get("/api/results")
def get_results():
    """Return aggregated benchmark scores for the frontend.

    This is a plain ``def`` rather than ``async def`` on purpose.
    ``load_dataset`` and ``to_pandas`` are blocking calls, and FastAPI runs
    sync path operations in a worker thread instead of on the event loop.

    Returns:
        The list produced by ``_summarize_results``, or a 500 ``JSONResponse``
        naming the missing columns if the dataset lacks any required column.
    """
    dataset = load_dataset("smolagents/results", "2024-12-26")
    df = dataset["train"].to_pandas()

    print("Dataset loaded, shape:", df.shape)
    print("Columns:", df.columns)

    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    for col in missing:
        print(f"Warning: Column {col} not found in dataset")
    if missing:
        # The grouping/filtering would raise KeyError on these columns, so
        # report the schema problem explicitly instead.
        return JSONResponse(
            status_code=500,
            content={"error": f"Dataset is missing required columns: {missing}"},
        )

    result = _summarize_results(df)
    print(f"Processed {len(result)} entries for the frontend")
    return result
|
|
|
|
|
|
|
|
|
# Serve the frontend from the ./static directory at the site root. With
# html=True, directory requests are answered with their index.html. The mount
# is registered after the API routes so that /api/results is matched by its
# route before this catch-all "/" mount.
_static_site = StaticFiles(html=True, directory="static")
app.mount("/", _static_site, name="static")
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Echo the effective launch configuration.
    print(args)

    # uvicorn needs a "module:attribute" import string (not the app object)
    # for reload to work. Derive the module name from this file so the
    # launcher still works if the file is renamed; for app.py the result is
    # exactly "app:app".
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        ssl_certfile=args.ssl_certfile,
        ssl_keyfile=args.ssl_keyfile,
    )