File size: 2,222 Bytes
bf5fb5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# ==================================
# Fichier: parameter_logging.py
# ==================================

from functools import singledispatch
from typing import Dict, Any
from base_trainer import BaseTrainer
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer


@singledispatch
def get_relevant_params_for_logging(trainer: BaseTrainer) -> Dict[str, Any]:
    """
    Méthode générique, par défaut, pour récupérer les paramètres
    à logger dans MLflow ou ailleurs.

    :param trainer: Trainer dont on veut extraire les paramètres.
    :return: Un dictionnaire de paramètres pertinents.
    """
    # Par défaut, on récupère la section 'params' du config.model,
    # mais l'UML ne précise pas la logique interne.
    return trainer.config.model.params


@get_relevant_params_for_logging.register
def _(trainer: HuggingFaceTransformerTrainer) -> Dict[str, Any]:
    """
    Cas particulier pour un HuggingFaceTransformerTrainer.
    """
    # Extrait les paramètres spécifiques HuggingFace indiqués dans trainer.config.model.params.
    return trainer.config.model.params


@get_relevant_params_for_logging.register
def _(trainer: SvmTrainer) -> Dict[str, Any]:
    """
    Cas particulier pour un SvmTrainer.
    """
    return trainer.config.model.params


@get_relevant_params_for_logging.register
def _(trainer: RandomForestTrainer) -> Dict[str, Any]:
    """
    Cas particulier pour un RandomForestTrainer.
    """
    return trainer.config.model.params


@get_relevant_params_for_logging.register
def _(trainer: LogisticRegressionTrainer) -> Dict[str, Any]:
    """
    Cas particulier pour un LogisticRegressionTrainer.
    """
    return trainer.config.model.params


@get_relevant_params_for_logging.register
def _(trainer: LinearRegressionTrainer) -> Dict[str, Any]:
    """
    Cas particulier pour un LinearRegressionTrainer.
    """
    return trainer.config.model.params