Spaces:
Sleeping
Sleeping
import logging | |
import torch | |
from huggingface_hub import hf_hub_download | |
from typing import List, Dict, Any | |
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS | |
from architecture.resnet import ResNet | |
logger = logging.getLogger(__name__) | |
# Configuration de PyTorch | |
torch.set_num_threads(NUM_THREADS) | |
# Instance de base pour le modèle ResNet | |
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle | |
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) | |
# Dictionnaire global pour stocker les modèles chargés et leurs métadonnées | |
# Clé: ID du modèle (provenant de la DB) | |
# Valeur: Dict{'pipeline': <modèle chargé>, 'metadata': {'model_id': ..., 'hf_filename': ...}} | |
model_pipelines: Dict[Any, Dict[str, Any]] = {} | |
async def _load_single_model_pipeline(model_data: Dict[str, Any], force_download: bool = False) -> Dict[str, Any]: | |
"""Charge un seul pipeline de modèle et retourne le pipeline ainsi que ses métadonnées. | |
Args: | |
model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.). | |
force_download: Si True, force le re-téléchargement depuis Hugging Face Hub. | |
Returns: | |
Un dictionnaire contenant le 'pipeline' chargé et les 'metadata'. | |
Raises: | |
Exception: Si le chargement échoue. | |
""" | |
model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé | |
model_name = model_data['hf_filename'] | |
repo_id = model_data['hf_repo_id'] | |
subfolder = model_data['hf_subfolder'] | |
updated_at = model_data.get('updated_at', None) # Récupérer updated_at | |
logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})") | |
try: | |
model_weight_path = hf_hub_download( | |
repo_id=repo_id, | |
subfolder=subfolder, | |
filename=model_name, | |
token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement | |
force_download=force_download # Utiliser le paramètre force_download | |
) | |
logger.debug(f"Model weights downloaded to: {model_weight_path}") | |
# Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique | |
# Assurez-vous que ResNet et ses arguments sont corrects | |
model = ResNet("resnet152", num_output_neurons=2).to(DEVICE) | |
# Charger les poids | |
# Attention: la méthode de chargement dépend du format des poids (state_dict, etc.) | |
state_dict = torch.load(model_weight_path, map_location=DEVICE) | |
# Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model') | |
if isinstance(state_dict, dict) and 'state_dict' in state_dict: | |
state_dict = state_dict['state_dict'] | |
elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun | |
state_dict = state_dict['model'] | |
# Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP) | |
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} | |
model.load_state_dict(state_dict) | |
model.eval() # Mettre le modèle en mode évaluation | |
metadata = { | |
'model_id': model_id, | |
'hf_filename': model_name, | |
'updated_at': str(updated_at) if updated_at else None # Ajouter updated_at (converti en str pour JSON) | |
# Ajoutez d'autres métadonnées utiles ici si nécessaire | |
} | |
logger.info(f"Successfully loaded model ID {model_id}: {model_name}") | |
return {'pipeline': model, 'metadata': metadata} | |
except Exception as e: | |
logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True) | |
raise # Propage l'exception pour que l'appelant puisse la gérer | |
async def load_models(models_data: List[Dict[str, Any]]) -> None: | |
"""Charger les modèles depuis Hugging Face et les stocker dans model_pipelines. | |
Args: | |
models_data: Liste de dictionnaires contenant les informations des modèles. | |
Raises: | |
RuntimeError: Si aucun modèle n'est trouvé. | |
""" | |
logger.info(f"Attempting to load {len(models_data)} models into memory...") | |
if not models_data: | |
error_msg = "No models data provided. Cannot load models." | |
logger.error(error_msg) | |
# On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles. | |
return | |
loaded_count = 0 | |
failed_models = [] | |
for model_data in models_data: | |
model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent | |
try: | |
# Utilise la nouvelle fonction pour charger un seul modèle et ses métadonnées | |
loaded_data = await _load_single_model_pipeline(model_data) | |
# Stocke l'ensemble (pipeline + metadata) dans le dictionnaire global | |
model_pipelines[model_id] = loaded_data | |
loaded_count += 1 | |
except Exception as e: | |
# Log l'échec mais continue avec les autres modèles | |
logger.error(f"Failed to load model ID {model_id}: {e}") | |
failed_models.append(model_data.get('display_name', f'ID {model_id}')) | |
logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}") | |
if failed_models: | |
logger.warning(f"Failed to load the following models: {', '.join(failed_models)}") | |
# Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global | |
# return model_pipelines # Ancienne logique | |
def get_model(model_name: str): | |
"""Récupérer un modèle chargé par son nom de fichier (hf_filename). | |
Args: | |
model_name: Nom du fichier du modèle (hf_filename) à récupérer | |
Returns: | |
Le modèle chargé (le pipeline) | |
Raises: | |
KeyError: Si le modèle n'est pas trouvé | |
""" | |
# Rechercher par hf_filename | |
for model_id, model_data in model_pipelines.items(): | |
metadata = model_data.get('metadata', {}) | |
if metadata.get('hf_filename') == model_name: | |
logger.info(f"Model found: {model_name} (ID: {model_id})") | |
return model_data['pipeline'] | |
# Si on arrive ici, le modèle n'a pas été trouvé | |
logger.error(f"Model {model_name} not found in loaded models") | |
raise KeyError(f"Model {model_name} not found") | |