import logging import os from contextlib import asynccontextmanager from fastapi import FastAPI, Request import gradio as gr from gradio.routes import mount_gradio_app from api.router import router from db.models import fetch_models_for_group from models.loader import load_models, model_pipelines from config.settings import RESOURCE_GROUP, DATABASE_URL from utils.system_monitor import get_memory_status, format_memory_status # Configuration de base des logs logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # --- Événements Startup/Shutdown (Lifespan Manager) --- @asynccontextmanager async def lifespan(app: FastAPI): # Code exécuté au démarrage logger.info("Starting up API...") logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels try: models_to_load = await fetch_models_for_group(RESOURCE_GROUP) if models_to_load: await load_models(models_to_load) logger.info("Initial models loaded successfully.") else: logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.") except Exception as e: logger.exception(f"Failed to load initial models during startup: {e}") # Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles # raise RuntimeError("Could not load initial models, API startup aborted.") from e yield # Code exécuté à l'arrêt logger.info("Shutting down API...") # Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes) # Créer l'application FastAPI app = FastAPI( title="Tamis AI Inference API", description="API pour l'inférence des modèles de classification d'objets", version="0.1.0", lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus ) # Ajout du Middleware ici @app.middleware("http") async def api_key_middleware(request: Request, call_next): """Middleware pour vérifier la clé API et exempter certaines routes.""" # Skip if we're in debug mode or during startup if os.environ.get("DEBUG") == "1": logger.debug("DEBUG mode active, skipping API key check.") return await call_next(request) # Liste des chemins publics ou internes à exempter de la vérification de clé public_paths = [ '/docs', '/openapi.json', # Documentation Swagger/OpenAPI '/health', # Health check endpoint '/', # Racine (Interface Gradio) '/assets/', # Assets Gradio '/file=', # Fichiers Gradio '/queue/', # Queue Gradio '/startup-logs', # Logs HF Space '/config', # Config Gradio/HF '/info', # Info Gradio/HF '/gradio', # Potentiel préfixe Gradio '/favicon.ico' # Favicon ] # Vérifie si le chemin commence par un des préfixes publics is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths) if is_public: logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.") response = await call_next(request) return response else: # Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même. # Ce middleware ne fait donc plus de vérification active ici, # il sert juste à logger et potentiellement à exempter certaines routes si besoin. logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.") response = await call_next(request) return response # Inclure les routes app.include_router(router) async def init_models(): """Charger les modèles au démarrage pour Gradio et FastAPI.""" logger.info("Initializing models for Gradio and FastAPI...") try: models_data = await fetch_models_for_group(RESOURCE_GROUP) await load_models(models_data) logger.info("Models loaded successfully.") except Exception as e: logger.error(f"Failed to initialize models: {e}", exc_info=True) # Decide if the app should fail to start or continue without models # raise RuntimeError("Model initialization failed.") # For now, let's log and continue, Gradio will show an empty list pass # Définir les fonctions pour Gradio qui récupèrent les modèles chargés def get_loaded_models_list(): """Retourne la liste des métadonnées des modèles actuellement chargés.""" # Extraire les métadonnées de chaque entrée dans model_pipelines return [item['metadata'] for item in model_pipelines.values()] def format_model_info(metadata_list): """Formate les informations des modèles pour un affichage plus convivial.""" if not metadata_list: return "Aucun modèle chargé actuellement." # Créer une version formatée des informations des modèles formatted_info = "" for model in metadata_list: # Utiliser le nom du fichier comme titre principal formatted_info += f"### Modèle : {model.get('hf_filename', 'N/A')}\n" # Mettre l'ID et la dernière mise à jour sur une même ligne formatted_info += f"**ID:** {model.get('model_id', 'N/A')} | **Dernière mise à jour:** {model.get('updated_at', 'N/A')}\n\n" return formatted_info # Créer l'interface Gradio gradio_app = gr.Blocks(title="Tamis AI - Modèles Chargés", theme=gr.themes.Soft()) with gradio_app: gr.Markdown("# 🤖 Tamis AI - Interface d'administration") gr.Markdown("## Modèles actuellement chargés dans l'API") with gr.Row(): with gr.Column(scale=2): # Visualisation des modèles avec des cartes with gr.Tab("Vue détaillée"): markdown_output = gr.Markdown(value="Chargement des modèles...", elem_id="model_details") # Affichage en tableau with gr.Tab("Vue tableau"): model_table = gr.Dataframe( headers=["ID", "Fichier", "Dernière mise à jour"], datatype=["str", "str", "str"], elem_id="model_table" ) # Vue JSON (pour référence) with gr.Tab("Vue JSON (debug)"): json_output = gr.JSON(label="Données brutes") with gr.Column(scale=1): refresh_btn = gr.Button("Rafraîchir", variant="primary") status = gr.Textbox(label="Statut", value="Chargement des informations...", interactive=False, lines=8) # Fonction pour mettre à jour tous les composants d'affichage def update_all_displays(): models = get_loaded_models_list() formatted_text = format_model_info(models) # Préparer les données pour le tableau table_data = [] for model in models: table_data.append([ model.get('model_id', 'N/A'), model.get('hf_filename', 'N/A'), model.get('updated_at', 'N/A') ]) # Récupérer et formater les informations de mémoire memory_status = get_memory_status(model_pipelines) status_text = format_memory_status(memory_status) return formatted_text, table_data, models, status_text # Connecter les événements refresh_btn.click( fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status] ) # Initialiser l'affichage de la mémoire dès le démarrage memory_status = get_memory_status(model_pipelines) status.value = format_memory_status(memory_status) # Initialiser les affichages au chargement gradio_app.load(fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status]) # Monter l'application Gradio à la racine dans FastAPI app = mount_gradio_app( app, gradio_app, path="/" ) # Event startup to load models (ensure it runs *after* Gradio is mounted if needed) # We call init_models inside startup @app.on_event("startup") async def startup(): """Initialiser l'API : charger les modèles depuis la base de données.""" await init_models() # Call the consolidated init function @app.get("/health") async def health_check(): """Point d'entrée pour vérifier l'état de l'API.""" return {"status": "healthy"}