inference-api-g1 / api /router.py
alexfremont's picture
Update requirements.txt with new package dependencies
832b9bb
raw
history blame
7.82 kB
from fastapi import APIRouter, Request, HTTPException
import logging
import os
from api import prediction
from config.settings import API_KEY, MANAGEMENT_API_KEY
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline, get_model
from models.schemas import PredictionRequest, PredictionResponse
logger = logging.getLogger(__name__)
# Router principal
router = APIRouter()
# Middleware d'authentification
async def verify_api_key(request: Request, call_next):
"""Middleware pour vérifier la clé API dans les en-têtes."""
# Skip if we're in debug mode or during startup
if os.environ.get("DEBUG") == "1":
return await call_next(request)
# Pour les routes Hugging Face Space (logs, etc.)
# Aussi pour les assets statiques de Gradio
if request.url.path.startswith('/assets/') \
or request.url.path == '/theme.css' \
or request.url.path.startswith('/queue/'):
return await call_next(request)
if request.url.path == "/" and "logs=container" in request.url.query:
return await call_next(request)
api_key = request.headers.get("x-api-key")
if api_key is None or api_key not in API_KEY.split(','):
logger.warning(f"Unauthorized API access attempt from {request.client.host}")
raise HTTPException(status_code=403, detail="Unauthorized")
response = await call_next(request)
return response
# Dépendance pour la Sécurité de l'API de Gestion
async def verify_management_api_key(x_api_key: str = Header(...)):
"""Vérifie si la clé API fournie correspond à celle configurée."""
if not MANAGEMENT_API_KEY:
logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!")
# Décider si on bloque ou on autorise sans clé définie
# Pour la sécurité, il vaut mieux bloquer par défaut
raise HTTPException(status_code=500, detail="Management API key not configured on server.")
if x_api_key != MANAGEMENT_API_KEY:
logger.warning(f"Invalid or missing API key attempt for management endpoint.")
raise HTTPException(status_code=401, detail="Invalid or missing API Key")
return True # Clé valide
# Inclure les routes des autres modules
router.include_router(prediction.router, tags=["Prediction"])
# Nouvel Endpoint de Gestion
@router.post(
"/manage/load_model/{model_db_id}",
summary="Load a specific model into memory",
dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
)
async def load_single_model(model_db_id: Any): # L'ID peut être int ou str
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
logger.info(f"Received request to load model with DB ID: {model_db_id}")
# 1. Vérifier si le modèle est déjà chargé
if model_db_id in model_pipelines:
logger.info(f"Model ID {model_db_id} is already loaded.")
return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."}
# 2. Récupérer les informations du modèle depuis la DB
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
logger.error(f"Model ID {model_db_id} not found in database.")
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
except Exception as e:
logger.exception(f"Database error fetching model ID {model_db_id}: {e}")
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.")
# 3. Charger le modèle
try:
logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...")
pipeline = await _load_single_model_pipeline(model_data)
# 4. Ajouter au dictionnaire des modèles chargés
model_pipelines[model_db_id] = pipeline
logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.")
return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."}
except Exception as e:
logger.exception(f"Failed to load model ID {model_db_id}: {e}")
# Ne pas laisser un pipeline potentiellement corrompu dans le dictionnaire
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.")
@router.post(
"/manage/update_model/{model_db_id}",
summary="Reload/Update a specific model already in memory",
dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
)
async def update_single_model(model_db_id: Any):
"""Retélécharge et met à jour un modèle spécifique qui est déjà chargé en mémoire."""
logger.info(f"Received request to update model with DB ID: {model_db_id}")
# 1. Vérifier si le modèle est actuellement chargé
if model_db_id not in model_pipelines:
logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.")
raise HTTPException(
status_code=404,
detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first."
)
# 2. Récupérer les informations du modèle depuis la DB (pour s'assurer qu'elles sont à jour si besoin)
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
# Ceci indiquerait une incohérence si le modèle est dans model_pipelines mais pas dans la DB
logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.")
raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.")
except Exception as e:
logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}")
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.")
# 3. Recharger le modèle
try:
logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...")
# Supprimer l'ancien modèle de la mémoire avant de charger le nouveau pour libérer des ressources GPU/CPU si possible
# Attention : ceci pourrait causer une brève indisponibilité du modèle pendant le rechargement.
# Une stratégie alternative serait de charger le nouveau d'abord, puis de remplacer.
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
# Potentiellement forcer le nettoyage de la mémoire GPU ici si nécessaire (torch.cuda.empty_cache() - à utiliser avec prudence)
logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.")
pipeline = await _load_single_model_pipeline(model_data)
# 4. Mettre à jour le dictionnaire avec le nouveau pipeline
model_pipelines[model_db_id] = pipeline
logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.")
return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."}
except Exception as e:
logger.exception(f"Failed to reload model ID {model_db_id}: {e}")
# Si le rechargement échoue, l'ancien modèle a déjà été supprimé.
# Il faut soit tenter de recharger l'ancien, soit le laisser déchargé.
# Pour l'instant, on le laisse déchargé et on signale l'erreur.
raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.")