Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 5,835 Bytes
bccef3b 419e526 dfb1c84 bccef3b 419e526 bccef3b dfb1c84 bccef3b dfb1c84 bccef3b db789ea ca20804 db789ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends
from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])
@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
if model_db_id in model_pipelines:
raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
# Charger le modèle et ses métadonnées
loaded_data = await _load_single_model_pipeline(model_data)
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
return {"message": f"Model {model_db_id} loaded successfully."}
except HTTPException as http_exc:
raise http_exc # Re-lancer les exceptions HTTP déjà formatées
except Exception as e:
logger.exception(f"Error loading model {model_db_id}: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")
@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
"""Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
"""
if model_db_id not in model_pipelines:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")
try:
model_data = await fetch_model_by_id(model_db_id)
if not model_data:
# Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")
logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
# Supprimer l'ancien modèle de la mémoire
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
logger.info(f"Removed old instance of model {model_db_id} from memory.")
# Tenter de recharger
loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
# Stocker le résultat (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_db_id] = loaded_data
# Si on arrive ici sans exception, le chargement a réussi
logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
return {"message": f"Model {model_db_id} updated successfully."}
except HTTPException as http_exc:
raise http_exc
except Exception as e:
logger.exception(f"Error updating model {model_db_id}: {e}")
# Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
if model_db_id in model_pipelines:
del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")
@router.post("/unload_model/{model_db_id}", summary="Unload a model from memory (POST alternative)")
async def delete_single_model(model_db_id: Any):
"""Décharge un modèle de la mémoire sans le supprimer de la base de données.
Ceci permet de libérer des ressources système.
Args:
model_db_id: L'ID du modèle à décharger.
Returns:
Un message de confirmation du déchargement.
Raises:
HTTPException: Si le modèle n'est pas trouvé en mémoire.
"""
logger.info(f"Request to unload model {model_db_id} from memory")
if model_db_id not in model_pipelines:
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.")
try:
# Récupérer le nom du fichier pour le logging avant de supprimer
filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
# Supprimer le modèle de la mémoire
del model_pipelines[model_db_id]
# Force le garbage collector pour libérer la mémoire
import gc
gc.collect()
logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
return {"message": f"Model {model_db_id} successfully unloaded from memory"}
except Exception as e:
logger.exception(f"Error unloading model {model_db_id}: {e}")
# Si une erreur se produit pendant le déchargement, on tente quand même de supprimer
if model_db_id in model_pipelines:
del model_pipelines[model_db_id]
raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}")
|