Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 8,283 Bytes
1ef2263 5df1f2d 68a1cf9 77f06e1 b26a6dc f85309c b26a6dc 5df1f2d 38a3c61 e109700 f73ccd7 5df1f2d e109700 5df1f2d e109700 f73ccd7 5df1f2d 38a3c61 e109700 38a3c61 b26a6dc df3bf97 b26a6dc dfb1c84 b26a6dc df3bf97 1445bc9 df3bf97 b26a6dc df3bf97 b26a6dc df3bf97 b26a6dc df3bf97 b26a6dc 38a3c61 e109700 1445bc9 b26a6dc e109700 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import gradio as gr
from gradio.routes import mount_gradio_app
from api.router import router
from db.models import fetch_models_for_group
from models.loader import load_models, model_pipelines
from config.settings import RESOURCE_GROUP, DATABASE_URL
# Configuration de base des logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# --- Événements Startup/Shutdown (Lifespan Manager) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
# Code exécuté au démarrage
logger.info("Starting up API...")
logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels
try:
models_to_load = await fetch_models_for_group(RESOURCE_GROUP)
if models_to_load:
await load_models(models_to_load)
logger.info("Initial models loaded successfully.")
else:
logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.")
except Exception as e:
logger.exception(f"Failed to load initial models during startup: {e}")
# Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles
# raise RuntimeError("Could not load initial models, API startup aborted.") from e
yield
# Code exécuté à l'arrêt
logger.info("Shutting down API...")
# Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes)
# Créer l'application FastAPI
app = FastAPI(
title="Tamis AI Inference API",
description="API pour l'inférence des modèles de classification d'objets",
version="0.1.0",
lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus
)
# Ajout du Middleware ici
@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
"""Middleware pour vérifier la clé API et exempter certaines routes."""
# Skip if we're in debug mode or during startup
if os.environ.get("DEBUG") == "1":
logger.debug("DEBUG mode active, skipping API key check.")
return await call_next(request)
# Liste des chemins publics ou internes à exempter de la vérification de clé
public_paths = [
'/docs', '/openapi.json', # Documentation Swagger/OpenAPI
'/health', # Health check endpoint
'/', # Racine (Interface Gradio)
'/assets/', # Assets Gradio
'/file=', # Fichiers Gradio
'/queue/', # Queue Gradio
'/startup-logs', # Logs HF Space
'/config', # Config Gradio/HF
'/info', # Info Gradio/HF
'/gradio', # Potentiel préfixe Gradio
'/favicon.ico' # Favicon
]
# Vérifie si le chemin commence par un des préfixes publics
is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths)
if is_public:
logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.")
response = await call_next(request)
return response
else:
# Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même.
# Ce middleware ne fait donc plus de vérification active ici,
# il sert juste à logger et potentiellement à exempter certaines routes si besoin.
logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.")
response = await call_next(request)
return response
# Inclure les routes
app.include_router(router)
async def init_models():
"""Charger les modèles au démarrage pour Gradio et FastAPI."""
logger.info("Initializing models for Gradio and FastAPI...")
try:
models_data = await fetch_models_for_group(RESOURCE_GROUP)
await load_models(models_data)
logger.info("Models loaded successfully.")
except Exception as e:
logger.error(f"Failed to initialize models: {e}", exc_info=True)
# Decide if the app should fail to start or continue without models
# raise RuntimeError("Model initialization failed.")
# For now, let's log and continue, Gradio will show an empty list
pass
# Définir les fonctions pour Gradio qui récupèrent les modèles chargés
def get_loaded_models_list():
"""Retourne la liste des métadonnées des modèles actuellement chargés."""
# Extraire les métadonnées de chaque entrée dans model_pipelines
return [item['metadata'] for item in model_pipelines.values()]
def format_model_info(metadata_list):
"""Formate les informations des modèles pour un affichage plus convivial."""
if not metadata_list:
return "Aucun modèle chargé actuellement."
# Créer une version formatée des informations des modèles
formatted_info = ""
for model in metadata_list:
# Utiliser le nom du fichier comme titre principal
formatted_info += f"### Modèle : {model.get('hf_filename', 'N/A')}\n"
# Mettre l'ID et la dernière mise à jour sur une même ligne
formatted_info += f"**ID:** {model.get('model_id', 'N/A')} | **Dernière mise à jour:** {model.get('updated_at', 'N/A')}\n\n"
return formatted_info
# Créer l'interface Gradio
gradio_app = gr.Blocks(title="Tamis AI - Modèles Chargés", theme=gr.themes.Soft())
with gradio_app:
gr.Markdown("# 🤖 Tamis AI - Interface d'administration")
gr.Markdown("## Modèles actuellement chargés dans l'API")
with gr.Row():
with gr.Column(scale=2):
# Visualisation des modèles avec des cartes
with gr.Tab("Vue détaillée"):
markdown_output = gr.Markdown(value="Chargement des modèles...", elem_id="model_details")
# Affichage en tableau
with gr.Tab("Vue tableau"):
model_table = gr.Dataframe(
headers=["ID", "Fichier", "Dernière mise à jour"],
datatype=["str", "str", "str"],
elem_id="model_table"
)
# Vue JSON (pour référence)
with gr.Tab("Vue JSON (debug)"):
json_output = gr.JSON(label="Données brutes")
with gr.Column(scale=1):
refresh_btn = gr.Button("Rafraîchir", variant="primary")
status = gr.Textbox(label="Statut", value="Prêt", interactive=False)
# Fonction pour mettre à jour tous les composants d'affichage
def update_all_displays():
models = get_loaded_models_list()
formatted_text = format_model_info(models)
# Préparer les données pour le tableau
table_data = []
for model in models:
table_data.append([
model.get('model_id', 'N/A'),
model.get('hf_filename', 'N/A'),
model.get('updated_at', 'N/A')
])
return formatted_text, table_data, models, "Modèles mis à jour"
# Connecter les événements
refresh_btn.click(
fn=update_all_displays,
outputs=[markdown_output, model_table, json_output, status]
)
# Initialiser les affichages au chargement
gradio_app.load(fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status])
# Monter l'application Gradio à la racine dans FastAPI
app = mount_gradio_app(
app, gradio_app, path="/"
)
# Event startup to load models (ensure it runs *after* Gradio is mounted if needed)
# We call init_models inside startup
@app.on_event("startup")
async def startup():
"""Initialiser l'API : charger les modèles depuis la base de données."""
await init_models() # Call the consolidated init function
@app.get("/health")
async def health_check():
"""Point d'entrée pour vérifier l'état de l'API."""
return {"status": "healthy"}
|