alexfremont's picture
Add model management endpoints and database fetch functionality
0053356
raw
history blame
5.39 kB
import logging
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict, Any
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet
logger = logging.getLogger(__name__)
# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)
# Instance de base pour le modèle ResNet
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Dictionnaire global pour stocker les modèles chargés (pipelines)
# Clé: ID du modèle (provenant de la DB), Valeur: Pipeline/Modèle chargé
model_pipelines: Dict[Any, Any] = {}
async def _load_single_model_pipeline(model_data: Dict[str, Any]) -> Any:
"""Charge un seul pipeline de modèle à partir de ses données.
Args:
model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
Returns:
Le pipeline/modèle chargé.
Raises:
Exception: Si le chargement échoue.
"""
model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
model_name = model_data['hf_filename']
repo_id = model_data['hf_repo_id']
subfolder = model_data['hf_subfolder']
logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
try:
model_weight_path = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename=model_name,
token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
)
logger.debug(f"Model weights downloaded to: {model_weight_path}")
# Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
# Assurez-vous que ResNet et ses arguments sont corrects
model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Charger les poids
# Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
state_dict = torch.load(model_weight_path, map_location=DEVICE)
# Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
if isinstance(state_dict, dict) and 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
state_dict = state_dict['model']
# Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval() # Mettre le modèle en mode évaluation
logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
return model # Retourner le modèle chargé (ou un pipeline si vous en créez un)
except Exception as e:
logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
raise # Propage l'exception pour que l'appelant puisse la gérer
async def load_models(models_data: List[Dict[str, Any]]) -> None:
"""Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
Args:
models_data: Liste de dictionnaires contenant les informations des modèles.
Raises:
RuntimeError: Si aucun modèle n'est trouvé.
"""
logger.info(f"Attempting to load {len(models_data)} models into memory...")
if not models_data:
error_msg = "No models data provided. Cannot load models."
logger.error(error_msg)
# On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
return
loaded_count = 0
failed_models = []
for model_data in models_data:
model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
try:
# Utilise la nouvelle fonction pour charger un seul modèle
pipeline = await _load_single_model_pipeline(model_data)
# Stocke le pipeline chargé dans le dictionnaire global
model_pipelines[model_id] = pipeline
loaded_count += 1
except Exception as e:
# Log l'échec mais continue avec les autres modèles
logger.error(f"Failed to load model ID {model_id}: {e}")
failed_models.append(model_data.get('name', f'ID {model_id}'))
logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
if failed_models:
logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")
# Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
# return model_pipelines # Ancienne logique
def get_model(model_name: str):
"""Récupérer un modèle chargé par son nom.
Args:
model_name: Nom du modèle à récupérer
Returns:
Le modèle chargé
Raises:
KeyError: Si le modèle n'est pas trouvé
"""
if model_name not in model_pipelines:
logger.error(f"Model {model_name} not found in loaded models")
raise KeyError(f"Model {model_name} not found")
return model_pipelines[model_name]