emotion_classifier / src /parameter_logging.py
fioriclass's picture
initialisation
bf5fb5f
# ==================================
# Fichier: parameter_logging.py
# ==================================
from functools import singledispatch
from typing import Dict, Any
from base_trainer import BaseTrainer
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
@singledispatch
def get_relevant_params_for_logging(trainer: BaseTrainer) -> Dict[str, Any]:
"""
Méthode générique, par défaut, pour récupérer les paramètres
à logger dans MLflow ou ailleurs.
:param trainer: Trainer dont on veut extraire les paramètres.
:return: Un dictionnaire de paramètres pertinents.
"""
# Par défaut, on récupère la section 'params' du config.model,
# mais l'UML ne précise pas la logique interne.
return trainer.config.model.params
@get_relevant_params_for_logging.register
def _(trainer: HuggingFaceTransformerTrainer) -> Dict[str, Any]:
"""
Cas particulier pour un HuggingFaceTransformerTrainer.
"""
# Extrait les paramètres spécifiques HuggingFace indiqués dans trainer.config.model.params.
return trainer.config.model.params
@get_relevant_params_for_logging.register
def _(trainer: SvmTrainer) -> Dict[str, Any]:
"""
Cas particulier pour un SvmTrainer.
"""
return trainer.config.model.params
@get_relevant_params_for_logging.register
def _(trainer: RandomForestTrainer) -> Dict[str, Any]:
"""
Cas particulier pour un RandomForestTrainer.
"""
return trainer.config.model.params
@get_relevant_params_for_logging.register
def _(trainer: LogisticRegressionTrainer) -> Dict[str, Any]:
"""
Cas particulier pour un LogisticRegressionTrainer.
"""
return trainer.config.model.params
@get_relevant_params_for_logging.register
def _(trainer: LinearRegressionTrainer) -> Dict[str, Any]:
"""
Cas particulier pour un LinearRegressionTrainer.
"""
return trainer.config.model.params