Spaces:
Sleeping
Sleeping
File size: 6,482 Bytes
e109700 0053356 e109700 0053356 e109700 0053356 e109700 dfb1c84 e109700 dfb1c84 e109700 0053356 638ed9d e109700 dfb1c84 0053356 dfb1c84 0053356 638ed9d 0053356 dfb1c84 0053356 dfb1c84 0053356 e109700 0053356 e109700 0053356 e109700 0053356 e109700 0053356 e109700 0053356 e109700 0053356 e109700 dfb1c84 e109700 0053356 c2cd706 0053356 e109700 9f9c6d5 e109700 9f9c6d5 e109700 dfb1c84 e109700 9f9c6d5 e109700 9f9c6d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import logging
import torch
from huggingface_hub import hf_hub_download
from typing import List, Dict, Any
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
from architecture.resnet import ResNet
logger = logging.getLogger(__name__)
# Configuration de PyTorch
torch.set_num_threads(NUM_THREADS)
# Instance de base pour le modèle ResNet
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Dictionnaire global pour stocker les modèles chargés et leurs métadonnées
# Clé: ID du modèle (provenant de la DB)
# Valeur: Dict{'pipeline': <modèle chargé>, 'metadata': {'model_id': ..., 'hf_filename': ...}}
model_pipelines: Dict[Any, Dict[str, Any]] = {}
async def _load_single_model_pipeline(model_data: Dict[str, Any], force_download: bool = False) -> Dict[str, Any]:
"""Charge un seul pipeline de modèle et retourne le pipeline ainsi que ses métadonnées.
Args:
model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
force_download: Si True, force le re-téléchargement depuis Hugging Face Hub.
Returns:
Un dictionnaire contenant le 'pipeline' chargé et les 'metadata'.
Raises:
Exception: Si le chargement échoue.
"""
model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
model_name = model_data['hf_filename']
repo_id = model_data['hf_repo_id']
subfolder = model_data['hf_subfolder']
updated_at = model_data.get('updated_at', None) # Récupérer updated_at
logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
try:
model_weight_path = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename=model_name,
token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
force_download=force_download # Utiliser le paramètre force_download
)
logger.debug(f"Model weights downloaded to: {model_weight_path}")
# Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
# Assurez-vous que ResNet et ses arguments sont corrects
model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
# Charger les poids
# Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
state_dict = torch.load(model_weight_path, map_location=DEVICE)
# Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
if isinstance(state_dict, dict) and 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
state_dict = state_dict['model']
# Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval() # Mettre le modèle en mode évaluation
metadata = {
'model_id': model_id,
'hf_filename': model_name,
'updated_at': str(updated_at) if updated_at else None # Ajouter updated_at (converti en str pour JSON)
# Ajoutez d'autres métadonnées utiles ici si nécessaire
}
logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
return {'pipeline': model, 'metadata': metadata}
except Exception as e:
logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
raise # Propage l'exception pour que l'appelant puisse la gérer
async def load_models(models_data: List[Dict[str, Any]]) -> None:
"""Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
Args:
models_data: Liste de dictionnaires contenant les informations des modèles.
Raises:
RuntimeError: Si aucun modèle n'est trouvé.
"""
logger.info(f"Attempting to load {len(models_data)} models into memory...")
if not models_data:
error_msg = "No models data provided. Cannot load models."
logger.error(error_msg)
# On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
return
loaded_count = 0
failed_models = []
for model_data in models_data:
model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
try:
# Utilise la nouvelle fonction pour charger un seul modèle et ses métadonnées
loaded_data = await _load_single_model_pipeline(model_data)
# Stocke l'ensemble (pipeline + metadata) dans le dictionnaire global
model_pipelines[model_id] = loaded_data
loaded_count += 1
except Exception as e:
# Log l'échec mais continue avec les autres modèles
logger.error(f"Failed to load model ID {model_id}: {e}")
failed_models.append(model_data.get('display_name', f'ID {model_id}'))
logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
if failed_models:
logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")
# Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
# return model_pipelines # Ancienne logique
def get_model(model_name: str):
"""Récupérer un modèle chargé par son nom de fichier (hf_filename).
Args:
model_name: Nom du fichier du modèle (hf_filename) à récupérer
Returns:
Le modèle chargé (le pipeline)
Raises:
KeyError: Si le modèle n'est pas trouvé
"""
# Rechercher par hf_filename
for model_id, model_data in model_pipelines.items():
metadata = model_data.get('metadata', {})
if metadata.get('hf_filename') == model_name:
logger.info(f"Model found: {model_name} (ID: {model_id})")
return model_data['pipeline']
# Si on arrive ici, le modèle n'a pas été trouvé
logger.error(f"Model {model_name} not found in loaded models")
raise KeyError(f"Model {model_name} not found")
|