File size: 5,366 Bytes
bf5fb5f
 
 
 
8ffb539
bf5fb5f
8ffb539
 
 
bf5fb5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ffb539
bf5fb5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ffb539
bf5fb5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# =======================================
# Fichier: optimizers/ray_tune_optimizer.py
# =======================================

from typing import Dict, Any, TYPE_CHECKING
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer

if TYPE_CHECKING:
    from base_trainer import BaseTrainer
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer


class RayTuneOptimizer(HyperparameterOptimizer):
    """
    Optimiseur d'hyperparamètres basé sur Ray Tune.
    Implémente l'interface HyperparameterOptimizer.
    """

    def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]:
        """
        Recherche les meilleurs hyperparamètres pour un 'trainer' donné,
        selon la grille 'param_grid', en utilisant Ray Tune.

        :param trainer: Instance d'une classe implémentant BaseTrainer.
        :param param_grid: Dictionnaire définissant l'espace de recherche
                           pour chaque hyperparamètre.
        :return: Un dictionnaire contenant les hyperparamètres optimaux trouvés.
        """
        full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config
        config = self._create_config(param_grid)
        scheduler = self._create_scheduler()
        reporter = self._create_reporter()

        analysis = tune.run(
            self._train_model,
            config={"hyperparameters": config, **full_config},
            num_samples=full_config['hyperparameters'].get('n_trials', 100),
            scheduler=scheduler,
            progress_reporter=reporter,
            resources_per_trial={"cpu": 1, "gpu": 0}
        )

        return analysis.best_config

    def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]:
        """
        Crée la configuration pour Ray Tune à partir du param_grid.

        :param param_grid: Dictionnaire définissant l'espace de recherche
                           pour chaque hyperparamètre.
        :return: Dictionnaire de configuration pour Ray Tune.
        """
        return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()}

    def _define_search_space(self, param: str, vals: Any) -> Any:
        """
        Définit l'espace de recherche pour un hyperparamètre donné.

        :param param: Nom de l'hyperparamètre.
        :param vals: Valeurs possibles ou dictionnaire définissant l'espace.
        :return: Espace de recherche Ray Tune.
        """
        if isinstance(vals, list):
            return tune.choice(vals)
        elif isinstance(vals, dict):
            low = vals.get('low', 0)
            high = vals.get('high', 10)
            log = vals.get('log', False)
            if 'low' in vals and 'high' in vals:
                return tune.uniform(param, low, high) if log else tune.randint(param, low, high)
        return tune.uniform(param, 0.0, 1.0)

    def _create_scheduler(self) -> ASHAScheduler:
        """
        Crée un scheduler ASHAScheduler pour Ray Tune.

        :return: Instance d'ASHAScheduler.
        """
        return ASHAScheduler(
            max_t=100,
            grace_period=10,
            reduction_factor=2
        )

    def _create_reporter(self) -> CLIReporter:
        """
        Crée un reporter CLIReporter pour Ray Tune.

        :return: Instance de CLIReporter.
        """
        return CLIReporter(
            metric_columns=["validation_loss", "training_iteration"]
        )

    def _train_model(self, config: Dict[str, Any]):
        """
        Fonction d'entraînement pour Ray Tune.

        :param config: Configuration des hyperparamètres.
        """
        merged_config = config.copy()
        hyperparams = merged_config.pop('hyperparameters', {})
        trainer_instance = self._get_trainer_instance(merged_config)
        trainer_instance.config.hyperparameters = hyperparams
        trainer_instance.train()
        results = trainer_instance.evaluate()
        tune.report(validation_loss=results.get('validation_loss', float('inf')))

    def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer":
        """
        Obtient une instance de BaseTrainer basée sur la configuration.

        :param config: Configuration globale incluant hyperparamètres.
        :return: Instance de BaseTrainer.
        """
        model_type = config['model']['type'].lower()
        trainer_mapping = {
            'svm': SvmTrainer,
            'random_forest': RandomForestTrainer,
            'logistic_regression': LogisticRegressionTrainer,
            'linear_regression': LinearRegressionTrainer,
            # Ajouter d'autres mappings ici si nécessaire
        }

        trainer_class = trainer_mapping.get(model_type)
        if not trainer_class:
            raise ValueError(f"Type de modèle non supporté : {model_type}")

        return trainer_class(
            config=config,
            data_path=config['data']['path'],
            target_column=config['data']['target_column']
        )