File size: 6,937 Bytes
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
dfb1c84
 
 
 
 
 
 
 
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
bccef3b
 
 
 
 
 
 
dfb1c84
 
 
 
bccef3b
dfb1c84
 
 
bccef3b
 
 
 
 
 
 
 
 
db789ea
ca20804
db789ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a8fe7c
 
 
 
 
 
 
 
 
 
 
 
 
db789ea
 
8a8fe7c
 
 
 
 
 
 
 
 
 
 
 
 
 
db789ea
 
8a8fe7c
db789ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends

from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])

@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
    """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
    if model_db_id in model_pipelines:
        raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")

        logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
        # Charger le modèle et ses métadonnées
        loaded_data = await _load_single_model_pipeline(model_data)
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
        return {"message": f"Model {model_db_id} loaded successfully."}

    except HTTPException as http_exc:
        raise http_exc # Re-lancer les exceptions HTTP déjà formatées
    except Exception as e:
        logger.exception(f"Error loading model {model_db_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")

@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
    """Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
    Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
    """
    if model_db_id not in model_pipelines:
        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            # Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")

        logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
        
        # Supprimer l'ancien modèle de la mémoire
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id]
            logger.info(f"Removed old instance of model {model_db_id} from memory.")
        
        # Tenter de recharger
        loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
        
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
        return {"message": f"Model {model_db_id} updated successfully."}

    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.exception(f"Error updating model {model_db_id}: {e}")
        # Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
        raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")

@router.post("/unload_model/{model_db_id}", summary="Unload a model from memory (POST alternative)")
async def delete_single_model(model_db_id: Any):
    """Décharge un modèle de la mémoire sans le supprimer de la base de données.
    Ceci permet de libérer des ressources système.
    
    Args:
        model_db_id: L'ID du modèle à décharger.
        
    Returns:
        Un message de confirmation du déchargement.
        
    Raises:
        HTTPException: Si le modèle n'est pas trouvé en mémoire.
    """
    logger.info(f"Request to unload model {model_db_id} from memory")
    
    if model_db_id not in model_pipelines:
        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded.")
    
    try:
        # Récupérer le nom du fichier pour le logging avant de supprimer
        filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
        
        # Récupérer une référence au modèle avant de le supprimer
        model_data = model_pipelines[model_db_id]
        pipeline = model_data.get('pipeline')
        
        # Détacher explicitement le modèle du GPU si applicable
        if hasattr(pipeline, 'to') and hasattr(pipeline, 'cpu'):
            try:
                pipeline.to('cpu')
                logger.info(f"Modèle {model_db_id} détaché du GPU")
            except Exception as e:
                logger.warning(f"Impossible de détacher le modèle du GPU: {e}")
        
        # Supprimer le modèle du dictionnaire
        del model_pipelines[model_db_id]
        
        # Supprimer explicitement les références
        del model_data
        del pipeline
        
        # Vider le cache PyTorch si disponible
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info("Cache CUDA vidé")
        except (ImportError, AttributeError) as e:
            logger.debug(f"Impossible de vider le cache CUDA: {e}")
        
        # Force le garbage collector plusieurs fois pour libérer la mémoire
        import gc
        gc.collect()
        gc.collect()  # Parfois un second appel aide à libérer plus de mémoire
        
        logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
        return {"message": f"Model {model_db_id} successfully unloaded from memory"}
    
    except Exception as e:
        logger.exception(f"Error unloading model {model_db_id}: {e}")
        # Si une erreur se produit pendant le déchargement, on tente quand même de supprimer
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id]
        raise HTTPException(status_code=500, detail=f"Internal server error while unloading model {model_db_id}")