File size: 5,835 Bytes
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
dfb1c84
 
 
 
 
 
 
 
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
bccef3b
 
 
 
 
 
 
dfb1c84
 
 
 
bccef3b
dfb1c84
 
 
bccef3b
 
 
 
 
 
 
 
 
db789ea
ca20804
db789ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends

from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])

@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
    """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
    if model_db_id in model_pipelines:
        raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")

        logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
        # Charger le modèle et ses métadonnées
        loaded_data = await _load_single_model_pipeline(model_data)
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
        return {"message": f"Model {model_db_id} loaded successfully."}

    except HTTPException as http_exc:
        raise http_exc # Re-lancer les exceptions HTTP déjà formatées
    except Exception as e:
        logger.exception(f"Error loading model {model_db_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")

@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
    """Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
    Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
    """
    if model_db_id not in model_pipelines:
        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            # Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")

        logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
        
        # Supprimer l'ancien modèle de la mémoire
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id]
            logger.info(f"Removed old instance of model {model_db_id} from memory.")
        
        # Tenter de recharger
        loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
        
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
        return {"message": f"Model {model_db_id} updated successfully."}

    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.exception(f"Error updating model {model_db_id}: {e}")
        # Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
        raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")

@router.post("/unload_model/{model_db_id}", summary="Unload a model from memory (POST alternative)")
async def delete_single_model(model_db_id: Any):
    """Décharge un modèle de la mémoire sans le supprimer de la base de données.
    Ceci permet de libérer des ressources système.
    
    Args:
        model_db_id: L'ID du modèle à décharger.
        
    Returns:
        Un message de confirmation du déchargement.
        
    Raises:
        HTTPException: Si le modèle n'est pas trouvé en mémoire.
    """
    logger.info(f"Request to unload model {model_db_id} from memory")
    
    if model_db_id not in model_pipelines:
        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.")
    
    try:
        # Récupérer le nom du fichier pour le logging avant de supprimer
        filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
        
        # Supprimer le modèle de la mémoire
        del model_pipelines[model_db_id]
        
        # Force le garbage collector pour libérer la mémoire
        import gc
        gc.collect()
        
        logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
        return {"message": f"Model {model_db_id} successfully unloaded from memory"}
    
    except Exception as e:
        logger.exception(f"Error unloading model {model_db_id}: {e}")
        # Si une erreur se produit pendant le déchargement, on tente quand même de supprimer
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id]
        raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}")