Spaces:
Sleeping
Sleeping
File size: 6,937 Bytes
bccef3b 419e526 dfb1c84 bccef3b 419e526 bccef3b dfb1c84 bccef3b dfb1c84 bccef3b db789ea ca20804 db789ea 8a8fe7c db789ea 8a8fe7c db789ea 8a8fe7c db789ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends
from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])
@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
if model_db_id in model_pipelines:
raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
# Charger le modèle et ses métadonnées
loaded_data = await _load_single_model_pipeline(model_data)
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
return {"message": f"Model {model_db_id} loaded successfully."}
except HTTPException as http_exc:
raise http_exc # Re-lancer les exceptions HTTP déjà formatées
except Exception as e:
logger.exception(f"Error loading model {model_db_id}: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")
@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
"""Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
"""
if model_db_id not in model_pipelines:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
# Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")
logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
# Supprimer l'ancien modèle de la mémoire
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
logger.info(f"Removed old instance of model {model_db_id} from memory.")
# Tenter de recharger
loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
return {"message": f"Model {model_db_id} updated successfully."}
except HTTPException as http_exc:
raise http_exc
except Exception as e:
logger.exception(f"Error updating model {model_db_id}: {e}")
# Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
if model_db_id in model_pipelines:
del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")
@router.post("/unload_model/{model_db_id}", summary="Unload a model from memory (POST alternative)")
async def delete_single_model(model_db_id: Any):
"""Décharge un modèle de la mémoire sans le supprimer de la base de données.
Ceci permet de libérer des ressources système.
Args:
model_db_id: L'ID du modèle à décharger.
Returns:
Un message de confirmation du déchargement.
Raises:
HTTPException: Si le modèle n'est pas trouvé en mémoire.
"""
logger.info(f"Request to unload model {model_db_id} from memory")
if model_db_id not in model_pipelines:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.")
try:
# Récupérer le nom du fichier pour le logging avant de supprimer
filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
# Récupérer une référence au modèle avant de le supprimer
model_data = model_pipelines[model_db_id]
pipeline = model_data.get('pipeline')
# Détacher explicitement le modèle du GPU si applicable
if hasattr(pipeline, 'to') and hasattr(pipeline, 'cpu'):
try:
pipeline.to('cpu')
logger.info(f"Modèle {model_db_id} détaché du GPU")
except Exception as e:
logger.warning(f"Impossible de détacher le modèle du GPU: {e}")
# Supprimer le modèle du dictionnaire
del model_pipelines[model_db_id]
# Supprimer explicitement les références
del model_data
del pipeline
# Vider le cache PyTorch si disponible
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Cache CUDA vidé")
except (ImportError, AttributeError) as e:
logger.debug(f"Impossible de vider le cache CUDA: {e}")
# Force le garbage collector plusieurs fois pour libérer la mémoire
import gc
gc.collect()
gc.collect() # Parfois un second appel aide à libérer plus de mémoire
logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
return {"message": f"Model {model_db_id} successfully unloaded from memory"}
except Exception as e:
logger.exception(f"Error unloading model {model_db_id}: {e}")
# Si une erreur se produit pendant le déchargement, on tente quand même de supprimer
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}")
|