Spaces:
Sleeping
Sleeping
import logging | |
from typing import Any, Dict, List | |
from fastapi import APIRouter, HTTPException, Depends | |
from api.dependencies import verify_management_api_key | |
from db.models import fetch_model_by_id | |
from models.loader import model_pipelines, _load_single_model_pipeline | |
logger = logging.getLogger(__name__) | |
router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)]) | |
async def load_single_model(model_db_id: Any): | |
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données.""" | |
if model_db_id in model_pipelines: | |
raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.") | |
try: | |
model_data = await fetch_model_by_id(model_db_id) | |
if not model_data: | |
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.") | |
logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...") | |
# Charger le modèle et ses métadonnées | |
loaded_data = await _load_single_model_pipeline(model_data) | |
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global | |
model_pipelines[model_db_id] = loaded_data | |
# Si on arrive ici sans exception, le chargement a réussi | |
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.") | |
return {"message": f"Model {model_db_id} loaded successfully."} | |
except HTTPException as http_exc: | |
raise http_exc # Re-lancer les exceptions HTTP déjà formatées | |
except Exception as e: | |
logger.exception(f"Error loading model {model_db_id}: {e}") | |
raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.") | |
async def update_single_model(model_db_id: Any): | |
"""Met à jour un modèle déjà chargé en le rechargeant depuis le Hub. | |
Le modèle doit être actuellement chargé pour pouvoir être mis à jour. | |
""" | |
if model_db_id not in model_pipelines: | |
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.") | |
try: | |
model_data = await fetch_model_by_id(model_db_id) | |
if not model_data: | |
# Devrait être peu probable si le modèle est chargé, mais vérification par sécurité | |
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.") | |
logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...") | |
# Supprimer l'ancien modèle de la mémoire | |
if model_db_id in model_pipelines: | |
del model_pipelines[model_db_id] | |
logger.info(f"Removed old instance of model {model_db_id} from memory.") | |
# Tenter de recharger | |
loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement | |
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global | |
model_pipelines[model_db_id] = loaded_data | |
# Si on arrive ici sans exception, le chargement a réussi | |
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.") | |
return {"message": f"Model {model_db_id} updated successfully."} | |
except HTTPException as http_exc: | |
raise http_exc | |
except Exception as e: | |
logger.exception(f"Error updating model {model_db_id}: {e}") | |
# Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne) | |
if model_db_id in model_pipelines: | |
del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue | |
raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.") | |
async def delete_single_model(model_db_id: Any): | |
"""Décharge un modèle de la mémoire sans le supprimer de la base de données. | |
Ceci permet de libérer des ressources système. | |
Args: | |
model_db_id: L'ID du modèle à décharger. | |
Returns: | |
Un message de confirmation du déchargement. | |
Raises: | |
HTTPException: Si le modèle n'est pas trouvé en mémoire. | |
""" | |
logger.info(f"Request to unload model {model_db_id} from memory") | |
if model_db_id not in model_pipelines: | |
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.") | |
try: | |
# Récupérer le nom du fichier pour le logging avant de supprimer | |
filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown') | |
# Récupérer une référence au modèle avant de le supprimer | |
model_data = model_pipelines[model_db_id] | |
pipeline = model_data.get('pipeline') | |
# Détacher explicitement le modèle du GPU si applicable | |
if hasattr(pipeline, 'to') and hasattr(pipeline, 'cpu'): | |
try: | |
pipeline.to('cpu') | |
logger.info(f"Modèle {model_db_id} détaché du GPU") | |
except Exception as e: | |
logger.warning(f"Impossible de détacher le modèle du GPU: {e}") | |
# Supprimer le modèle du dictionnaire | |
del model_pipelines[model_db_id] | |
# Supprimer explicitement les références | |
del model_data | |
del pipeline | |
# Vider le cache PyTorch si disponible | |
try: | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
logger.info("Cache CUDA vidé") | |
except (ImportError, AttributeError) as e: | |
logger.debug(f"Impossible de vider le cache CUDA: {e}") | |
# Force le garbage collector plusieurs fois pour libérer la mémoire | |
import gc | |
gc.collect() | |
gc.collect() # Parfois un second appel aide à libérer plus de mémoire | |
logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory") | |
return {"message": f"Model {model_db_id} successfully unloaded from memory"} | |
except Exception as e: | |
logger.exception(f"Error unloading model {model_db_id}: {e}") | |
# Si une erreur se produit pendant le déchargement, on tente quand même de supprimer | |
if model_db_id in model_pipelines: | |
del model_pipelines[model_db_id] | |
raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}") | |