# NOTE: removed non-Python page-scraping residue that preceded the module
# ("Spaces: Running on CPU Upgrade" header from the Hugging Face Space page).
import torch
import logging
from typing import Dict, List, Any
from huggingface_hub import hf_hub_download
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet

logger = logging.getLogger(__name__)

# PyTorch configuration: cap intra-op CPU threads (this Space runs on CPU).
torch.set_num_threads(NUM_THREADS)

# Base instance of the ResNet model; each checkpoint loaded later gets its
# own fresh instance of the same architecture (see load_models).
base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)

# Global registry of loaded models, keyed by their HF filename.
model_pipelines: Dict[str, Any] = {}
async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Load models from Hugging Face based on database records.

    Args:
        models_data: List of dicts describing each model; each entry must
            provide the keys 'hf_repo_id', 'hf_subfolder' and 'hf_filename'.

    Returns:
        Dict of the loaded models, keyed by HF filename.

    Raises:
        RuntimeError: If no model records are given, or if none of them
            could be loaded.
    """
    logger.info(f"Attempting to load {len(models_data)} models...")

    if not models_data:
        error_msg = "No models found. API cannot start without models."
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    loaded_count = 0
    for model_data in models_data:
        try:
            model_name = model_data['hf_filename']
            logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
            model_weight = hf_hub_download(
                repo_id=model_data['hf_repo_id'],
                subfolder=model_data['hf_subfolder'],
                filename=model_name,
                token=HF_TOKEN,
            )
            # Fresh instance per checkpoint so each model holds its own
            # specific weights (reusing one instance would overwrite them).
            model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
            model.load_state_dict(
                torch.load(model_weight, weights_only=True, map_location=DEVICE)
            )
            model.eval()  # inference mode: freezes dropout / batch-norm stats
            model_pipelines[model_name] = model
            loaded_count += 1
        except Exception as e:
            # Best-effort: log and skip the failing model, keep loading the rest.
            logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)

    logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")

    if loaded_count == 0:
        error_msg = "Failed to load any models. API cannot start without models."
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    return model_pipelines
def get_model(model_name: str):
    """Retrieve a loaded model by its name.

    Args:
        model_name: Name (HF filename) of the model to retrieve.

    Returns:
        The loaded model instance.

    Raises:
        KeyError: If the model is not found among the loaded models.
    """
    if model_name not in model_pipelines:
        logger.error(f"Model {model_name} not found in loaded models")
        raise KeyError(f"Model {model_name} not found")
    return model_pipelines[model_name]