inference-api-g1 / api /management.py
alexfremont's picture
Refactor model loading to store metadata alongside pipelines in model_pipelines dict
dfb1c84
raw
history blame
4.11 kB
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends
from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])
@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
if model_db_id in model_pipelines:
raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
# Charger le modèle et ses métadonnées
loaded_data = await _load_single_model_pipeline(model_data)
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
return {"message": f"Model {model_db_id} loaded successfully."}
except HTTPException as http_exc:
raise http_exc # Re-lancer les exceptions HTTP déjà formatées
except Exception as e:
logger.exception(f"Error loading model {model_db_id}: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")
@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
"""Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
"""
if model_db_id not in model_pipelines:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
# Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")
logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
# Supprimer l'ancien modèle de la mémoire
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
logger.info(f"Removed old instance of model {model_db_id} from memory.")
# Tenter de recharger
loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
return {"message": f"Model {model_db_id} updated successfully."}
except HTTPException as http_exc:
raise http_exc
except Exception as e:
logger.exception(f"Error updating model {model_db_id}: {e}")
# Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
if model_db_id in model_pipelines:
del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")