import logging from typing import Any, Dict, List from fastapi import APIRouter, HTTPException, Depends from api.dependencies import verify_management_api_key from db.models import fetch_model_by_id from models.loader import model_pipelines, _load_single_model_pipeline logger = logging.getLogger(__name__) router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)]) @router.post("/load_model/{model_db_id}", summary="Load a specific model into memory") async def load_single_model(model_db_id: Any): """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données.""" if model_db_id in model_pipelines: raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.") try: model_data = await fetch_model_by_id(model_db_id) if not model_data: raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.") logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...") # Charger le modèle et ses métadonnées loaded_data = await _load_single_model_pipeline(model_data) # Stocker le résultat (pipeline + metadata) dans le dictionnaire global model_pipelines[model_db_id] = loaded_data # Si on arrive ici sans exception, le chargement a réussi logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.") return {"message": f"Model {model_db_id} loaded successfully."} except HTTPException as http_exc: raise http_exc # Re-lancer les exceptions HTTP déjà formatées except Exception as e: logger.exception(f"Error loading model {model_db_id}: {e}") raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.") @router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model") async def update_single_model(model_db_id: Any): """Met à jour un modèle déjà chargé en le rechargeant depuis le Hub. Le modèle doit être actuellement chargé pour pouvoir être mis à jour. """ if model_db_id not in model_pipelines: raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.") try: model_data = await fetch_model_by_id(model_db_id) if not model_data: # Devrait être peu probable si le modèle est chargé, mais vérification par sécurité raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.") logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...") # Supprimer l'ancien modèle de la mémoire if model_db_id in model_pipelines: del model_pipelines[model_db_id] logger.info(f"Removed old instance of model {model_db_id} from memory.") # Tenter de recharger loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement # Stocker le résultat (pipeline + metadata) dans le dictionnaire global model_pipelines[model_db_id] = loaded_data # Si on arrive ici sans exception, le chargement a réussi logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.") return {"message": f"Model {model_db_id} updated successfully."} except HTTPException as http_exc: raise http_exc except Exception as e: logger.exception(f"Error updating model {model_db_id}: {e}") # Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne) if model_db_id in model_pipelines: del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.") @router.post("/unload_model/{model_db_id}", summary="Unload a model from memory (POST alternative)") async def delete_single_model(model_db_id: Any): """Décharge un modèle de la mémoire sans le supprimer de la base de données. Ceci permet de libérer des ressources système. Args: model_db_id: L'ID du modèle à décharger. Returns: Un message de confirmation du déchargement. Raises: HTTPException: Si le modèle n'est pas trouvé en mémoire. """ logger.info(f"Request to unload model {model_db_id} from memory") if model_db_id not in model_pipelines: raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.") try: # Récupérer le nom du fichier pour le logging avant de supprimer filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown') # Récupérer une référence au modèle avant de le supprimer model_data = model_pipelines[model_db_id] pipeline = model_data.get('pipeline') # Détacher explicitement le modèle du GPU si applicable if hasattr(pipeline, 'to') and hasattr(pipeline, 'cpu'): try: pipeline.to('cpu') logger.info(f"Modèle {model_db_id} détaché du GPU") except Exception as e: logger.warning(f"Impossible de détacher le modèle du GPU: {e}") # Supprimer le modèle du dictionnaire del model_pipelines[model_db_id] # Supprimer explicitement les références del model_data del pipeline # Vider le cache PyTorch si disponible try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Cache CUDA vidé") except (ImportError, AttributeError) as e: logger.debug(f"Impossible de vider le cache CUDA: {e}") # Force le garbage collector plusieurs fois pour libérer la mémoire import gc gc.collect() gc.collect() # Parfois un second appel aide à libérer plus de mémoire logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory") return {"message": f"Model {model_db_id} successfully unloaded from memory"} except Exception as e: logger.exception(f"Error unloading model {model_db_id}: {e}") # Si une erreur se produit pendant le déchargement, on tente quand même de supprimer if model_db_id in model_pipelines: del model_pipelines[model_db_id] raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}")