Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 3,062 Bytes
e109700 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
import logging
from typing import Dict, List, Any
from huggingface_hub import hf_hub_download
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet
logger = logging.getLogger(__name__)
# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)
# Instance de base pour le modèle ResNet
base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Dictionnaire global pour stocker les modèles chargés
model_pipelines = {}
async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Charger les modèles depuis Hugging Face à partir des données de la base de données.
Args:
models_data: Liste de dictionnaires contenant les informations des modèles
Returns:
Dictionnaire des modèles chargés
Raises:
RuntimeError: Si aucun modèle n'est trouvé ou ne peut être chargé
"""
logger.info(f"Attempting to load {len(models_data)} models...")
if not models_data:
error_msg = "No models found. API cannot start without models."
logger.error(error_msg)
raise RuntimeError(error_msg)
loaded_count = 0
for model_data in models_data:
try:
model_name = model_data['hf_filename']
logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
model_weight = hf_hub_download(
repo_id=model_data['hf_repo_id'],
subfolder=model_data['hf_subfolder'],
filename=model_name,
token=HF_TOKEN,
)
# Créer une nouvelle instance pour chaque modèle pour tenir ses poids spécifiques
model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
model.load_state_dict(
torch.load(model_weight, weights_only=True, map_location=DEVICE)
)
model.eval()
model_pipelines[model_name] = model
loaded_count += 1
except Exception as e:
logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)
logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")
if loaded_count == 0:
error_msg = "Failed to load any models. API cannot start without models."
logger.error(error_msg)
raise RuntimeError(error_msg)
return model_pipelines
def get_model(model_name: str):
"""Récupérer un modèle chargé par son nom.
Args:
model_name: Nom du modèle à récupérer
Returns:
Le modèle chargé
Raises:
KeyError: Si le modèle n'est pas trouvé
"""
if model_name not in model_pipelines:
logger.error(f"Model {model_name} not found in loaded models")
raise KeyError(f"Model {model_name} not found")
return model_pipelines[model_name]
|