File size: 5,393 Bytes
e109700
0053356
e109700
0053356
e109700
 
 
 
 
 
 
 
 
 
0053356
 
e109700
0053356
 
 
e109700
0053356
 
e109700
 
0053356
e109700
 
0053356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e109700
 
0053356
e109700
0053356
e109700
 
0053356
e109700
0053356
 
e109700
 
0053356
e109700
0053356
e109700
0053356
 
 
 
e109700
 
0053356
 
 
 
 
 
 
 
 
 
 
e109700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import logging
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict, Any

from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet

logger = logging.getLogger(__name__)

# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)

# Instance de base pour le modèle ResNet
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)

# Dictionnaire global pour stocker les modèles chargés (pipelines)
# Clé: ID du modèle (provenant de la DB), Valeur: Pipeline/Modèle chargé
model_pipelines: Dict[Any, Any] = {}

async def _load_single_model_pipeline(model_data: Dict[str, Any]) -> Any:
    """Charge un seul pipeline de modèle à partir de ses données.
    
    Args:
        model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
        
    Returns:
        Le pipeline/modèle chargé.
        
    Raises:
        Exception: Si le chargement échoue.
    """
    model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
    model_name = model_data['hf_filename']
    repo_id = model_data['hf_repo_id']
    subfolder = model_data['hf_subfolder']
    
    logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
    
    try:
        model_weight_path = hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename=model_name,
            token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
        )
        
        logger.debug(f"Model weights downloaded to: {model_weight_path}")
        
        # Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
        # Assurez-vous que ResNet et ses arguments sont corrects
        model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) 
        
        # Charger les poids
        # Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
        state_dict = torch.load(model_weight_path, map_location=DEVICE)
        
        # Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
        if isinstance(state_dict, dict) and 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
            state_dict = state_dict['model']
            
        # Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.eval() # Mettre le modèle en mode évaluation
        
        logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
        return model # Retourner le modèle chargé (ou un pipeline si vous en créez un)
        
    except Exception as e:
        logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
        raise # Propage l'exception pour que l'appelant puisse la gérer


async def load_models(models_data: List[Dict[str, Any]]) -> None:
    """Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
    
    Args:
        models_data: Liste de dictionnaires contenant les informations des modèles.
        
    Raises:
        RuntimeError: Si aucun modèle n'est trouvé.
    """
    logger.info(f"Attempting to load {len(models_data)} models into memory...")
    
    if not models_data:
        error_msg = "No models data provided. Cannot load models."
        logger.error(error_msg)
        # On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
        return
    
    loaded_count = 0
    failed_models = []
    for model_data in models_data:
        model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
        try:
            # Utilise la nouvelle fonction pour charger un seul modèle
            pipeline = await _load_single_model_pipeline(model_data)
            # Stocke le pipeline chargé dans le dictionnaire global
            model_pipelines[model_id] = pipeline
            loaded_count += 1
        except Exception as e:
            # Log l'échec mais continue avec les autres modèles
            logger.error(f"Failed to load model ID {model_id}: {e}")
            failed_models.append(model_data.get('name', f'ID {model_id}'))

    logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
    if failed_models:
        logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")

    # Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
    # return model_pipelines # Ancienne logique


def get_model(model_name: str):
    """Récupérer un modèle chargé par son nom.
    
    Args:
        model_name: Nom du modèle à récupérer
        
    Returns:
        Le modèle chargé
        
    Raises:
        KeyError: Si le modèle n'est pas trouvé
    """
    if model_name not in model_pipelines:
        logger.error(f"Model {model_name} not found in loaded models")
        raise KeyError(f"Model {model_name} not found")
    
    return model_pipelines[model_name]