File size: 6,482 Bytes
e109700
0053356
e109700
0053356
e109700
 
 
 
 
 
 
 
 
 
0053356
 
e109700
dfb1c84
 
 
 
e109700
dfb1c84
 
e109700
 
0053356
638ed9d
e109700
 
dfb1c84
0053356
 
 
 
 
 
 
 
dfb1c84
0053356
 
 
 
 
 
 
 
 
638ed9d
0053356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfb1c84
 
 
 
 
 
 
0053356
dfb1c84
0053356
 
 
 
 
 
 
 
 
 
 
e109700
 
0053356
e109700
0053356
e109700
 
0053356
e109700
0053356
 
e109700
 
0053356
e109700
0053356
e109700
dfb1c84
 
 
 
e109700
 
0053356
 
c2cd706
0053356
 
 
 
 
 
 
 
e109700
9f9c6d5
 
e109700
 
9f9c6d5
e109700
 
dfb1c84
e109700
 
 
 
9f9c6d5
 
 
 
 
 
e109700
9f9c6d5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import logging
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict, Any

from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet

logger = logging.getLogger(__name__)

# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)

# Instance de base pour le modèle ResNet
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)

# Dictionnaire global pour stocker les modèles chargés et leurs métadonnées
# Clé: ID du modèle (provenant de la DB)
# Valeur: Dict{'pipeline': <modèle chargé>, 'metadata': {'model_id': ..., 'hf_filename': ...}}
model_pipelines: Dict[Any, Dict[str, Any]] = {}

async def _load_single_model_pipeline(model_data: Dict[str, Any], force_download: bool = False) -> Dict[str, Any]:
    """Charge un seul pipeline de modèle et retourne le pipeline ainsi que ses métadonnées.
    
    Args:
        model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
        force_download: Si True, force le re-téléchargement depuis Hugging Face Hub.
        
    Returns:
        Un dictionnaire contenant le 'pipeline' chargé et les 'metadata'.
        
    Raises:
        Exception: Si le chargement échoue.
    """
    model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
    model_name = model_data['hf_filename']
    repo_id = model_data['hf_repo_id']
    subfolder = model_data['hf_subfolder']
    updated_at = model_data.get('updated_at', None) # Récupérer updated_at
    
    logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
    
    try:
        model_weight_path = hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename=model_name,
            token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
            force_download=force_download # Utiliser le paramètre force_download
        )
        
        logger.debug(f"Model weights downloaded to: {model_weight_path}")
        
        # Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
        # Assurez-vous que ResNet et ses arguments sont corrects
        model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) 
        
        # Charger les poids
        # Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
        state_dict = torch.load(model_weight_path, map_location=DEVICE)
        
        # Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
        if isinstance(state_dict, dict) and 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
            state_dict = state_dict['model']
            
        # Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.eval() # Mettre le modèle en mode évaluation
        
        metadata = {
            'model_id': model_id,
            'hf_filename': model_name,
            'updated_at': str(updated_at) if updated_at else None # Ajouter updated_at (converti en str pour JSON)
            # Ajoutez d'autres métadonnées utiles ici si nécessaire
        }
        
        logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
        return {'pipeline': model, 'metadata': metadata}
        
    except Exception as e:
        logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
        raise # Propage l'exception pour que l'appelant puisse la gérer


async def load_models(models_data: List[Dict[str, Any]]) -> None:
    """Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
    
    Args:
        models_data: Liste de dictionnaires contenant les informations des modèles.
        
    Raises:
        RuntimeError: Si aucun modèle n'est trouvé.
    """
    logger.info(f"Attempting to load {len(models_data)} models into memory...")
    
    if not models_data:
        error_msg = "No models data provided. Cannot load models."
        logger.error(error_msg)
        # On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
        return
    
    loaded_count = 0
    failed_models = []
    for model_data in models_data:
        model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
        try:
            # Utilise la nouvelle fonction pour charger un seul modèle et ses métadonnées
            loaded_data = await _load_single_model_pipeline(model_data)
            # Stocke l'ensemble (pipeline + metadata) dans le dictionnaire global
            model_pipelines[model_id] = loaded_data
            loaded_count += 1
        except Exception as e:
            # Log l'échec mais continue avec les autres modèles
            logger.error(f"Failed to load model ID {model_id}: {e}")
            failed_models.append(model_data.get('display_name', f'ID {model_id}'))

    logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
    if failed_models:
        logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")

    # Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
    # return model_pipelines # Ancienne logique


def get_model(model_name: str):
    """Récupérer un modèle chargé par son nom de fichier (hf_filename).
    
    Args:
        model_name: Nom du fichier du modèle (hf_filename) à récupérer
        
    Returns:
        Le modèle chargé (le pipeline)
        
    Raises:
        KeyError: Si le modèle n'est pas trouvé
    """
    # Rechercher par hf_filename
    for model_id, model_data in model_pipelines.items():
        metadata = model_data.get('metadata', {})
        if metadata.get('hf_filename') == model_name:
            logger.info(f"Model found: {model_name} (ID: {model_id})")
            return model_data['pipeline']
    
    # Si on arrive ici, le modèle n'a pas été trouvé
    logger.error(f"Model {model_name} not found in loaded models")
    raise KeyError(f"Model {model_name} not found")