Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from fastapi import APIRouter, Request, HTTPException | |
import logging | |
import os | |
from api import prediction | |
from config.settings import API_KEY, MANAGEMENT_API_KEY | |
from db.models import fetch_model_by_id | |
from models.loader import model_pipelines, _load_single_model_pipeline, get_model | |
from models.schemas import PredictionRequest, PredictionResponse | |
logger = logging.getLogger(__name__) | |
# Router principal | |
router = APIRouter() | |
# Middleware d'authentification | |
async def verify_api_key(request: Request, call_next): | |
"""Middleware pour vérifier la clé API dans les en-têtes.""" | |
# Skip if we're in debug mode or during startup | |
if os.environ.get("DEBUG") == "1": | |
return await call_next(request) | |
# Pour les routes Hugging Face Space (logs, etc.) | |
# Aussi pour les assets statiques de Gradio | |
if request.url.path.startswith('/assets/') \ | |
or request.url.path == '/theme.css' \ | |
or request.url.path.startswith('/queue/'): | |
return await call_next(request) | |
if request.url.path == "/" and "logs=container" in request.url.query: | |
return await call_next(request) | |
api_key = request.headers.get("x-api-key") | |
if api_key is None or api_key not in API_KEY.split(','): | |
logger.warning(f"Unauthorized API access attempt from {request.client.host}") | |
raise HTTPException(status_code=403, detail="Unauthorized") | |
response = await call_next(request) | |
return response | |
# Dépendance pour la Sécurité de l'API de Gestion | |
async def verify_management_api_key(x_api_key: str = Header(...)): | |
"""Vérifie si la clé API fournie correspond à celle configurée.""" | |
if not MANAGEMENT_API_KEY: | |
logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!") | |
# Décider si on bloque ou on autorise sans clé définie | |
# Pour la sécurité, il vaut mieux bloquer par défaut | |
raise HTTPException(status_code=500, detail="Management API key not configured on server.") | |
if x_api_key != MANAGEMENT_API_KEY: | |
logger.warning(f"Invalid or missing API key attempt for management endpoint.") | |
raise HTTPException(status_code=401, detail="Invalid or missing API Key") | |
return True # Clé valide | |
# Inclure les routes des autres modules | |
router.include_router(prediction.router, tags=["Prediction"]) | |
# Nouvel Endpoint de Gestion | |
async def load_single_model(model_db_id: Any): # L'ID peut être int ou str | |
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données.""" | |
logger.info(f"Received request to load model with DB ID: {model_db_id}") | |
# 1. Vérifier si le modèle est déjà chargé | |
if model_db_id in model_pipelines: | |
logger.info(f"Model ID {model_db_id} is already loaded.") | |
return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."} | |
# 2. Récupérer les informations du modèle depuis la DB | |
try: | |
model_data = await fetch_model_by_id(model_db_id) | |
if not model_data: | |
logger.error(f"Model ID {model_db_id} not found in database.") | |
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.") | |
except Exception as e: | |
logger.exception(f"Database error fetching model ID {model_db_id}: {e}") | |
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.") | |
# 3. Charger le modèle | |
try: | |
logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...") | |
pipeline = await _load_single_model_pipeline(model_data) | |
# 4. Ajouter au dictionnaire des modèles chargés | |
model_pipelines[model_db_id] = pipeline | |
logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.") | |
return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."} | |
except Exception as e: | |
logger.exception(f"Failed to load model ID {model_db_id}: {e}") | |
# Ne pas laisser un pipeline potentiellement corrompu dans le dictionnaire | |
if model_db_id in model_pipelines: | |
del model_pipelines[model_db_id] | |
raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.") | |
async def update_single_model(model_db_id: Any): | |
"""Retélécharge et met à jour un modèle spécifique qui est déjà chargé en mémoire.""" | |
logger.info(f"Received request to update model with DB ID: {model_db_id}") | |
# 1. Vérifier si le modèle est actuellement chargé | |
if model_db_id not in model_pipelines: | |
logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.") | |
raise HTTPException( | |
status_code=404, | |
detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first." | |
) | |
# 2. Récupérer les informations du modèle depuis la DB (pour s'assurer qu'elles sont à jour si besoin) | |
try: | |
model_data = await fetch_model_by_id(model_db_id) | |
if not model_data: | |
# Ceci indiquerait une incohérence si le modèle est dans model_pipelines mais pas dans la DB | |
logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.") | |
raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.") | |
except Exception as e: | |
logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}") | |
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.") | |
# 3. Recharger le modèle | |
try: | |
logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...") | |
# Supprimer l'ancien modèle de la mémoire avant de charger le nouveau pour libérer des ressources GPU/CPU si possible | |
# Attention : ceci pourrait causer une brève indisponibilité du modèle pendant le rechargement. | |
# Une stratégie alternative serait de charger le nouveau d'abord, puis de remplacer. | |
if model_db_id in model_pipelines: | |
del model_pipelines[model_db_id] | |
# Potentiellement forcer le nettoyage de la mémoire GPU ici si nécessaire (torch.cuda.empty_cache() - à utiliser avec prudence) | |
logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.") | |
pipeline = await _load_single_model_pipeline(model_data) | |
# 4. Mettre à jour le dictionnaire avec le nouveau pipeline | |
model_pipelines[model_db_id] = pipeline | |
logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.") | |
return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."} | |
except Exception as e: | |
logger.exception(f"Failed to reload model ID {model_db_id}: {e}") | |
# Si le rechargement échoue, l'ancien modèle a déjà été supprimé. | |
# Il faut soit tenter de recharger l'ancien, soit le laisser déchargé. | |
# Pour l'instant, on le laisse déchargé et on signale l'erreur. | |
raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.") | |