alexfremont's picture
Update model lookup to use filename instead of ID in get_model function
9f9c6d5
import logging
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict, Any
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet
logger = logging.getLogger(__name__)
# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)
# Instance de base pour le modèle ResNet
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Dictionnaire global pour stocker les modèles chargés et leurs métadonnées
# Clé: ID du modèle (provenant de la DB)
# Valeur: Dict{'pipeline': <modèle chargé>, 'metadata': {'model_id': ..., 'hf_filename': ...}}
model_pipelines: Dict[Any, Dict[str, Any]] = {}
async def _load_single_model_pipeline(model_data: Dict[str, Any], force_download: bool = False) -> Dict[str, Any]:
"""Charge un seul pipeline de modèle et retourne le pipeline ainsi que ses métadonnées.
Args:
model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
force_download: Si True, force le re-téléchargement depuis Hugging Face Hub.
Returns:
Un dictionnaire contenant le 'pipeline' chargé et les 'metadata'.
Raises:
Exception: Si le chargement échoue.
"""
model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
model_name = model_data['hf_filename']
repo_id = model_data['hf_repo_id']
subfolder = model_data['hf_subfolder']
updated_at = model_data.get('updated_at', None) # Récupérer updated_at
logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
try:
model_weight_path = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename=model_name,
token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
force_download=force_download # Utiliser le paramètre force_download
)
logger.debug(f"Model weights downloaded to: {model_weight_path}")
# Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
# Assurez-vous que ResNet et ses arguments sont corrects
model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Charger les poids
# Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
state_dict = torch.load(model_weight_path, map_location=DEVICE)
# Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
if isinstance(state_dict, dict) and 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
state_dict = state_dict['model']
# Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval() # Mettre le modèle en mode évaluation
metadata = {
'model_id': model_id,
'hf_filename': model_name,
'updated_at': str(updated_at) if updated_at else None # Ajouter updated_at (converti en str pour JSON)
# Ajoutez d'autres métadonnées utiles ici si nécessaire
}
logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
return {'pipeline': model, 'metadata': metadata}
except Exception as e:
logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
raise # Propage l'exception pour que l'appelant puisse la gérer
async def load_models(models_data: List[Dict[str, Any]]) -> None:
"""Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
Args:
models_data: Liste de dictionnaires contenant les informations des modèles.
Raises:
RuntimeError: Si aucun modèle n'est trouvé.
"""
logger.info(f"Attempting to load {len(models_data)} models into memory...")
if not models_data:
error_msg = "No models data provided. Cannot load models."
logger.error(error_msg)
# On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
return
loaded_count = 0
failed_models = []
for model_data in models_data:
model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
try:
# Utilise la nouvelle fonction pour charger un seul modèle et ses métadonnées
loaded_data = await _load_single_model_pipeline(model_data)
# Stocke l'ensemble (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_id] = loaded_data
loaded_count += 1
except Exception as e:
# Log l'échec mais continue avec les autres modèles
logger.error(f"Failed to load model ID {model_id}: {e}")
failed_models.append(model_data.get('display_name', f'ID {model_id}'))
logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
if failed_models:
logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")
# Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
# return model_pipelines # Ancienne logique
def get_model(model_name: str):
"""Récupérer un modèle chargé par son nom de fichier (hf_filename).
Args:
model_name: Nom du fichier du modèle (hf_filename) à récupérer
Returns:
Le modèle chargé (le pipeline)
Raises:
KeyError: Si le modèle n'est pas trouvé
"""
# Rechercher par hf_filename
for model_id, model_data in model_pipelines.items():
metadata = model_data.get('metadata', {})
if metadata.get('hf_filename') == model_name:
logger.info(f"Model found: {model_name} (ID: {model_id})")
return model_data['pipeline']
# Si on arrive ici, le modèle n'a pas été trouvé
logger.error(f"Model {model_name} not found in loaded models")
raise KeyError(f"Model {model_name} not found")