File size: 3,062 Bytes
e109700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import logging
from typing import Dict, List, Any
from huggingface_hub import hf_hub_download

from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet

logger = logging.getLogger(__name__)

# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)

# Instance de base pour le modèle ResNet
base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)

# Dictionnaire global pour stocker les modèles chargés
model_pipelines = {}

async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Charger les modèles depuis Hugging Face à partir des données de la base de données.
    
    Args:
        models_data: Liste de dictionnaires contenant les informations des modèles
        
    Returns:
        Dictionnaire des modèles chargés
        
    Raises:
        RuntimeError: Si aucun modèle n'est trouvé ou ne peut être chargé
    """
    logger.info(f"Attempting to load {len(models_data)} models...")
    
    if not models_data:
        error_msg = "No models found. API cannot start without models."
        logger.error(error_msg)
        raise RuntimeError(error_msg)
    
    loaded_count = 0
    for model_data in models_data:
        try:
            model_name = model_data['hf_filename']
            logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
            
            model_weight = hf_hub_download(
                repo_id=model_data['hf_repo_id'],
                subfolder=model_data['hf_subfolder'],
                filename=model_name,
                token=HF_TOKEN,
            )
            
            # Créer une nouvelle instance pour chaque modèle pour tenir ses poids spécifiques
            model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
            model.load_state_dict(
                torch.load(model_weight, weights_only=True, map_location=DEVICE)
            )
            model.eval()
            model_pipelines[model_name] = model
            loaded_count += 1
        except Exception as e:
            logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)
    
    logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")
    
    if loaded_count == 0:
        error_msg = "Failed to load any models. API cannot start without models."
        logger.error(error_msg)
        raise RuntimeError(error_msg)
    
    return model_pipelines

def get_model(model_name: str):
    """Récupérer un modèle chargé par son nom.
    
    Args:
        model_name: Nom du modèle à récupérer
        
    Returns:
        Le modèle chargé
        
    Raises:
        KeyError: Si le modèle n'est pas trouvé
    """
    if model_name not in model_pipelines:
        logger.error(f"Model {model_name} not found in loaded models")
        raise KeyError(f"Model {model_name} not found")
    
    return model_pipelines[model_name]