# ======================================= # Fichier: optimizers/ray_tune_optimizer.py # ======================================= from typing import Dict, Any, TYPE_CHECKING from interfaces.hyperparameter_optimizer import HyperparameterOptimizer if TYPE_CHECKING: from base_trainer import BaseTrainer import ray from ray import tune from ray.tune import CLIReporter from ray.tune.schedulers import ASHAScheduler from trainers.cuml.svm_trainer import SvmTrainer from trainers.cuml.random_forest_trainer import RandomForestTrainer from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer class RayTuneOptimizer(HyperparameterOptimizer): """ Optimiseur d'hyperparamètres basé sur Ray Tune. Implémente l'interface HyperparameterOptimizer. """ def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]: """ Recherche les meilleurs hyperparamètres pour un 'trainer' donné, selon la grille 'param_grid', en utilisant Ray Tune. :param trainer: Instance d'une classe implémentant BaseTrainer. :param param_grid: Dictionnaire définissant l'espace de recherche pour chaque hyperparamètre. :return: Un dictionnaire contenant les hyperparamètres optimaux trouvés. """ full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config config = self._create_config(param_grid) scheduler = self._create_scheduler() reporter = self._create_reporter() analysis = tune.run( self._train_model, config={"hyperparameters": config, **full_config}, num_samples=full_config['hyperparameters'].get('n_trials', 100), scheduler=scheduler, progress_reporter=reporter, resources_per_trial={"cpu": 1, "gpu": 0} ) return analysis.best_config def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]: """ Crée la configuration pour Ray Tune à partir du param_grid. :param param_grid: Dictionnaire définissant l'espace de recherche pour chaque hyperparamètre. :return: Dictionnaire de configuration pour Ray Tune. """ return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()} def _define_search_space(self, param: str, vals: Any) -> Any: """ Définit l'espace de recherche pour un hyperparamètre donné. :param param: Nom de l'hyperparamètre. :param vals: Valeurs possibles ou dictionnaire définissant l'espace. :return: Espace de recherche Ray Tune. """ if isinstance(vals, list): return tune.choice(vals) elif isinstance(vals, dict): low = vals.get('low', 0) high = vals.get('high', 10) log = vals.get('log', False) if 'low' in vals and 'high' in vals: return tune.uniform(param, low, high) if log else tune.randint(param, low, high) return tune.uniform(param, 0.0, 1.0) def _create_scheduler(self) -> ASHAScheduler: """ Crée un scheduler ASHAScheduler pour Ray Tune. :return: Instance d'ASHAScheduler. """ return ASHAScheduler( max_t=100, grace_period=10, reduction_factor=2 ) def _create_reporter(self) -> CLIReporter: """ Crée un reporter CLIReporter pour Ray Tune. :return: Instance de CLIReporter. """ return CLIReporter( metric_columns=["validation_loss", "training_iteration"] ) def _train_model(self, config: Dict[str, Any]): """ Fonction d'entraînement pour Ray Tune. :param config: Configuration des hyperparamètres. """ merged_config = config.copy() hyperparams = merged_config.pop('hyperparameters', {}) trainer_instance = self._get_trainer_instance(merged_config) trainer_instance.config.hyperparameters = hyperparams trainer_instance.train() results = trainer_instance.evaluate() tune.report(validation_loss=results.get('validation_loss', float('inf'))) def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer": """ Obtient une instance de BaseTrainer basée sur la configuration. :param config: Configuration globale incluant hyperparamètres. :return: Instance de BaseTrainer. """ model_type = config['model']['type'].lower() trainer_mapping = { 'svm': SvmTrainer, 'random_forest': RandomForestTrainer, 'logistic_regression': LogisticRegressionTrainer, 'linear_regression': LinearRegressionTrainer, # Ajouter d'autres mappings ici si nécessaire } trainer_class = trainer_mapping.get(model_type) if not trainer_class: raise ValueError(f"Type de modèle non supporté : {model_type}") return trainer_class( config=config, data_path=config['data']['path'], target_column=config['data']['target_column'] )