import logging import torch from huggingface_hub import hf_hub_download from typing import List, Dict, Any from config.settings import DEVICE, HF_TOKEN, NUM_THREADS from architecture.resnet import ResNet logger = logging.getLogger(__name__) # Configuration de PyTorch torch.set_num_threads(NUM_THREADS) # Instance de base pour le modèle ResNet # Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle # base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) # Dictionnaire global pour stocker les modèles chargés et leurs métadonnées # Clé: ID du modèle (provenant de la DB) # Valeur: Dict{'pipeline': , 'metadata': {'model_id': ..., 'hf_filename': ...}} model_pipelines: Dict[Any, Dict[str, Any]] = {} async def _load_single_model_pipeline(model_data: Dict[str, Any], force_download: bool = False) -> Dict[str, Any]: """Charge un seul pipeline de modèle et retourne le pipeline ainsi que ses métadonnées. Args: model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.). force_download: Si True, force le re-téléchargement depuis Hugging Face Hub. Returns: Un dictionnaire contenant le 'pipeline' chargé et les 'metadata'. Raises: Exception: Si le chargement échoue. """ model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé model_name = model_data['hf_filename'] repo_id = model_data['hf_repo_id'] subfolder = model_data['hf_subfolder'] updated_at = model_data.get('updated_at', None) # Récupérer updated_at logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})") try: model_weight_path = hf_hub_download( repo_id=repo_id, subfolder=subfolder, filename=model_name, token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement force_download=force_download # Utiliser le paramètre force_download ) logger.debug(f"Model weights downloaded to: {model_weight_path}") # Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique # Assurez-vous que ResNet et ses arguments sont corrects model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) # Charger les poids # Attention: la méthode de chargement dépend du format des poids (state_dict, etc.) state_dict = torch.load(model_weight_path, map_location=DEVICE) # Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model') if isinstance(state_dict, dict) and 'state_dict' in state_dict: state_dict = state_dict['state_dict'] elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun state_dict = state_dict['model'] # Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP) state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} model.load_state_dict(state_dict) model.eval() # Mettre le modèle en mode évaluation metadata = { 'model_id': model_id, 'hf_filename': model_name, 'updated_at': str(updated_at) if updated_at else None # Ajouter updated_at (converti en str pour JSON) # Ajoutez d'autres métadonnées utiles ici si nécessaire } logger.info(f"Successfully loaded model ID {model_id}: {model_name}") return {'pipeline': model, 'metadata': metadata} except Exception as e: logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True) raise # Propage l'exception pour que l'appelant puisse la gérer async def load_models(models_data: List[Dict[str, Any]]) -> None: """Charger les modèles depuis Hugging Face et les stocker dans model_pipelines. Args: models_data: Liste de dictionnaires contenant les informations des modèles. Raises: RuntimeError: Si aucun modèle n'est trouvé. """ logger.info(f"Attempting to load {len(models_data)} models into memory...") if not models_data: error_msg = "No models data provided. Cannot load models." logger.error(error_msg) # On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles. return loaded_count = 0 failed_models = [] for model_data in models_data: model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent try: # Utilise la nouvelle fonction pour charger un seul modèle et ses métadonnées loaded_data = await _load_single_model_pipeline(model_data) # Stocke l'ensemble (pipeline + metadata) dans le dictionnaire global model_pipelines[model_id] = loaded_data loaded_count += 1 except Exception as e: # Log l'échec mais continue avec les autres modèles logger.error(f"Failed to load model ID {model_id}: {e}") failed_models.append(model_data.get('display_name', f'ID {model_id}')) logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}") if failed_models: logger.warning(f"Failed to load the following models: {', '.join(failed_models)}") # Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global # return model_pipelines # Ancienne logique def get_model(model_name: str): """Récupérer un modèle chargé par son nom de fichier (hf_filename). Args: model_name: Nom du fichier du modèle (hf_filename) à récupérer Returns: Le modèle chargé (le pipeline) Raises: KeyError: Si le modèle n'est pas trouvé """ # Rechercher par hf_filename for model_id, model_data in model_pipelines.items(): metadata = model_data.get('metadata', {}) if metadata.get('hf_filename') == model_name: logger.info(f"Model found: {model_name} (ID: {model_id})") return model_data['pipeline'] # Si on arrive ici, le modèle n'a pas été trouvé logger.error(f"Model {model_name} not found in loaded models") raise KeyError(f"Model {model_name} not found")