"""FastAPI application serving a small web UI and a JSON API.

The API exposes the model list and per-model configurations returned by the
``optimum.neuron.utils`` hub-cache helpers (``get_hub_cached_models`` and
``get_hub_cached_entries``) in ``"inference"`` mode.
"""

import logging
import os
import sys

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils

# Configure logging: everything at INFO and above goes to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Absolute path to the "static" directory that sits next to this file.
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
logger.info("Static directory path: %s", static_dir)

# Mount static files and templates.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# NOTE(review): unlike static_dir, this path is relative and therefore resolved
# against the process working directory -- confirm the service is always
# started from the project root.
templates = Jinja2Templates(directory="app/templates")


def _unique_models(model_list):
    """Convert ``(architecture, org, model_id)`` tuples into de-duplicated dicts.

    The first tuple seen for a given ``"org/model_id"`` wins; later tuples with
    the same full id are dropped.

    Args:
        model_list: Iterable of 3-tuples ``(architecture, org, model_id)``.

    Returns:
        A list of ``{"id", "name", "type"}`` dicts in iteration order.
    """
    models = []
    seen_models = set()
    for architecture, org, model_id in model_list:
        full_model_id = f"{org}/{model_id}"
        if full_model_id in seen_models:
            continue
        seen_models.add(full_model_id)
        models.append({
            "id": full_model_id,
            "name": full_model_id,  # This will be used as the title
            "type": architecture,  # This will be used as the subtitle
        })
    return models


@app.get("/health")
async def health_check():
    """Return a static ``{"status": "healthy"}`` payload."""
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


@app.get("/")
async def home(request: Request):
    """Render the ``index.html`` template for the landing page."""
    logger.info("Home page requested")
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of cached models as JSON.

    Returns:
        A ``JSONResponse`` holding a list of ``{"id", "name", "type"}`` objects
        (an empty list when no models are reported), or a 500 response with an
        ``{"error": ...}`` body if the lookup raises.
    """
    logger.info("Fetching model list")
    try:
        # Ensure cache directory exists
        os.makedirs("/cache", exist_ok=True)

        # The lookup is a plain synchronous call (presumably doing network or
        # disk I/O); run it in the threadpool so it cannot stall the event
        # loop and, with it, every other request such as /health.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info("Found %d models", len(model_list))

        if not model_list:
            logger.warning("No models found")
            return JSONResponse(content=[])

        # Transform the data into the expected format
        models = _unique_models(model_list)
        logger.info("Returning %d unique models", len(models))
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logger.exception("Error fetching models: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )


# The ":path" convertor is required because ids returned by /api/models have
# the form "org/model_id"; the default convertor stops at the first "/" and
# such ids would never reach this handler.
@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Return the cached configurations for one model as JSON.

    Args:
        model_id: Model identifier taken from the URL; may contain ``/``
            (for example ``"org/model_id"``).

    Returns:
        A ``JSONResponse`` of the form ``{"configurations": [...]}`` (empty
        list when nothing is found), or a 500 response with an
        ``{"error": ...}`` body if the lookup raises.
    """
    logger.info("Fetching configurations for model: %s", model_id)
    try:
        # Synchronous lookup; keep it off the event loop (see get_model_list).
        configs = await run_in_threadpool(
            utils.get_hub_cached_entries, model_id=model_id, mode="inference"
        )
        logger.info("Found %d configurations for model %s", len(configs), model_id)

        # Return empty list if no configurations found
        if not configs:
            return JSONResponse(content={"configurations": []})

        return JSONResponse(content={"configurations": configs})
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logger.exception(
            "Error fetching configurations for model %s: %s", model_id, e
        )
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )