Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request | |
from fastapi.templating import Jinja2Templates | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import JSONResponse | |
from optimum.neuron import utils | |
import logging | |
import sys | |
import os | |
# --- Logging -----------------------------------------------------------------
# Emit INFO-and-above records to stdout with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# --- Application -------------------------------------------------------------
app = FastAPI()

# Static assets are resolved relative to this source file (not the working
# directory), so the mount does not depend on where the server is launched.
static_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "static",
)
logger.info("Static directory path: %s", static_dir)
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# NOTE(review): unlike static_dir, this path is relative to the current working
# directory, so templates resolve only when the server starts from the
# directory containing "app/" — confirm that is the intended launch location.
templates = Jinja2Templates(directory="app/templates")
async def health_check():
    """Liveness probe: log the call and report a fixed "healthy" status.

    NOTE(review): no route decorator is visible in this copy of the file;
    confirm how this handler is registered on ``app``.
    """
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload
async def home(request: Request):
    """Render the ``index.html`` landing page.

    Args:
        request: The incoming request, passed to the template under the
            ``"request"`` key.

    NOTE(review): no route decorator is visible in this copy of the file;
    confirm how this handler is registered on ``app``.
    """
    logger.info("Home page requested")
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def get_model_list():
    """List models that have cached inference configurations.

    ``utils.get_hub_cached_models`` is a synchronous library call (presumably
    doing Hub network I/O), so it is run in FastAPI's thread pool rather than
    directly on the event loop. Otherwise every concurrent request, including
    health checks, would stall until it returned.

    Returns:
        JSONResponse: on success, a list of ``{"id", "name", "type"}`` dicts,
        one per unique ``"org/model_id"``, sorted by id (``[]`` when no models
        are found). On any exception, status 500 with ``{"error": <message>}``.
    """
    # Function-scope import keeps the module's import block unchanged.
    from fastapi.concurrency import run_in_threadpool

    logger.info("Fetching model list")
    try:
        # Ensure cache directory exists
        os.makedirs("/cache", exist_ok=True)
        # Offload the blocking lookup so the event loop stays responsive.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")
        if not model_list:
            logger.warning("No models found")
            return JSONResponse(content=[])

        # Each entry unpacks as (architecture, org, model_id). Deduplicate on
        # "org/model_id", keeping the first architecture seen for each id.
        unique = {}
        for architecture, org, model_id in model_list:
            unique.setdefault(f"{org}/{model_id}", architecture)

        # The upstream collection's iteration order is not guaranteed to be
        # stable, so sort by id to give the UI a deterministic listing.
        models = [
            {
                "id": full_model_id,
                "name": full_model_id,  # This will be used as the title
                "type": architecture,  # This will be used as the subtitle
            }
            for full_model_id, architecture in sorted(unique.items())
        ]
        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level request boundary: log with traceback, report a 500.
        logger.error(f"Error fetching models: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
async def get_model_info_endpoint(model_id: str):
    """Return the cached inference configurations for one model.

    Args:
        model_id: Model identifier forwarded to
            ``utils.get_hub_cached_entries`` (presumably ``"org/name"`` —
            confirm against the route definition).

    Returns:
        JSONResponse: ``{"configurations": [...]}``, with an empty list when
        nothing is cached. On any exception, status 500 with
        ``{"error": <message>}``.
    """
    # Function-scope import keeps the module's import block unchanged.
    from fastapi.concurrency import run_in_threadpool

    logger.info(f"Fetching configurations for model: {model_id}")
    try:
        # Synchronous library call (presumably Hub network I/O): run it in the
        # thread pool so it does not block the event loop.
        configs = await run_in_threadpool(
            utils.get_hub_cached_entries, model_id=model_id, mode="inference"
        )
        # Normalize a falsy result (e.g. None) to [] *before* calling len(), so
        # "nothing cached" yields an empty list instead of a TypeError -> 500.
        configs = configs or []
        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    except Exception as e:
        # Top-level request boundary: log with traceback, report a 500.
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )