alexfremont's picture
Refactor API architecture with modular design and database integration
e109700
raw
history blame
3.06 kB
import torch
import logging
from typing import Dict, List, Any
from huggingface_hub import hf_hub_download
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet
logger = logging.getLogger(__name__)
# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)
# Instance de base pour le modèle ResNet
base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Dictionnaire global pour stocker les modèles chargés
model_pipelines = {}
async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Charger les modèles depuis Hugging Face à partir des données de la base de données.
Args:
models_data: Liste de dictionnaires contenant les informations des modèles
Returns:
Dictionnaire des modèles chargés
Raises:
RuntimeError: Si aucun modèle n'est trouvé ou ne peut être chargé
"""
logger.info(f"Attempting to load {len(models_data)} models...")
if not models_data:
error_msg = "No models found. API cannot start without models."
logger.error(error_msg)
raise RuntimeError(error_msg)
loaded_count = 0
for model_data in models_data:
try:
model_name = model_data['hf_filename']
logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
model_weight = hf_hub_download(
repo_id=model_data['hf_repo_id'],
subfolder=model_data['hf_subfolder'],
filename=model_name,
token=HF_TOKEN,
)
# Créer une nouvelle instance pour chaque modèle pour tenir ses poids spécifiques
model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
model.load_state_dict(
torch.load(model_weight, weights_only=True, map_location=DEVICE)
)
model.eval()
model_pipelines[model_name] = model
loaded_count += 1
except Exception as e:
logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)
logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")
if loaded_count == 0:
error_msg = "Failed to load any models. API cannot start without models."
logger.error(error_msg)
raise RuntimeError(error_msg)
return model_pipelines
def get_model(model_name: str):
"""Récupérer un modèle chargé par son nom.
Args:
model_name: Nom du modèle à récupérer
Returns:
Le modèle chargé
Raises:
KeyError: Si le modèle n'est pas trouvé
"""
if model_name not in model_pipelines:
logger.error(f"Model {model_name} not found in loaded models")
raise KeyError(f"Model {model_name} not found")
return model_pipelines[model_name]