File size: 4,112 Bytes
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
dfb1c84
 
 
 
 
 
 
 
bccef3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419e526
bccef3b
 
 
 
 
 
 
dfb1c84
 
 
 
bccef3b
dfb1c84
 
 
bccef3b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import logging
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Depends

from api.dependencies import verify_management_api_key
from db.models import fetch_model_by_id
from models.loader import model_pipelines, _load_single_model_pipeline

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])

@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
async def load_single_model(model_db_id: Any):
    """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
    if model_db_id in model_pipelines:
        raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")

        logger.info(f"Loading model {model_db_id} ({model_data['hf_filename']}) on demand...")
        # Charger le modèle et ses métadonnées
        loaded_data = await _load_single_model_pipeline(model_data)
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) loaded successfully.")
        return {"message": f"Model {model_db_id} loaded successfully."}

    except HTTPException as http_exc:
        raise http_exc # Re-lancer les exceptions HTTP déjà formatées
    except Exception as e:
        logger.exception(f"Error loading model {model_db_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")

@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
async def update_single_model(model_db_id: Any):
    """Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
    Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
    """
    if model_db_id not in model_pipelines:
        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")

    try:
        model_data = await fetch_model_by_id(model_db_id)
        if not model_data:
            # Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")

        logger.info(f"Updating (reloading) model {model_db_id} ({model_data['hf_filename']})...")
        
        # Supprimer l'ancien modèle de la mémoire
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id]
            logger.info(f"Removed old instance of model {model_db_id} from memory.")
        
        # Tenter de recharger
        loaded_data = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
        
        # Stocker le résultat (pipeline + metadata) dans le dictionnaire global
        model_pipelines[model_db_id] = loaded_data
        
        # Si on arrive ici sans exception, le chargement a réussi
        logger.info(f"Model {model_db_id} ({model_data['hf_filename']}) updated successfully.")
        return {"message": f"Model {model_db_id} updated successfully."}

    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.exception(f"Error updating model {model_db_id}: {e}")
        # Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
        if model_db_id in model_pipelines:
            del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
        raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")