|
|
|
|
|
|
|
|
|
from functools import singledispatch |
|
from typing import Dict, Any |
|
from base_trainer import BaseTrainer |
|
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer |
|
from trainers.cuml.svm_trainer import SvmTrainer |
|
from trainers.cuml.random_forest_trainer import RandomForestTrainer |
|
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer |
|
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer |
|
|
|
|
|
@singledispatch |
|
def get_relevant_params_for_logging(trainer: BaseTrainer) -> Dict[str, Any]: |
|
""" |
|
Méthode générique, par défaut, pour récupérer les paramètres |
|
à logger dans MLflow ou ailleurs. |
|
|
|
:param trainer: Trainer dont on veut extraire les paramètres. |
|
:return: Un dictionnaire de paramètres pertinents. |
|
""" |
|
|
|
|
|
return trainer.config.model.params |
|
|
|
|
|
@get_relevant_params_for_logging.register |
|
def _(trainer: HuggingFaceTransformerTrainer) -> Dict[str, Any]: |
|
""" |
|
Cas particulier pour un HuggingFaceTransformerTrainer. |
|
""" |
|
|
|
return trainer.config.model.params |
|
|
|
|
|
@get_relevant_params_for_logging.register |
|
def _(trainer: SvmTrainer) -> Dict[str, Any]: |
|
""" |
|
Cas particulier pour un SvmTrainer. |
|
""" |
|
return trainer.config.model.params |
|
|
|
|
|
@get_relevant_params_for_logging.register |
|
def _(trainer: RandomForestTrainer) -> Dict[str, Any]: |
|
""" |
|
Cas particulier pour un RandomForestTrainer. |
|
""" |
|
return trainer.config.model.params |
|
|
|
|
|
@get_relevant_params_for_logging.register |
|
def _(trainer: LogisticRegressionTrainer) -> Dict[str, Any]: |
|
""" |
|
Cas particulier pour un LogisticRegressionTrainer. |
|
""" |
|
return trainer.config.model.params |
|
|
|
|
|
@get_relevant_params_for_logging.register |
|
def _(trainer: LinearRegressionTrainer) -> Dict[str, Any]: |
|
""" |
|
Cas particulier pour un LinearRegressionTrainer. |
|
""" |
|
return trainer.config.model.params |
|
|