Spaces:
Sleeping
Sleeping
# ======================================= | |
# Fichier: optimizers/ray_tune_optimizer.py | |
# ======================================= | |
from typing import Dict, Any, TYPE_CHECKING | |
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer | |
if TYPE_CHECKING: | |
from base_trainer import BaseTrainer | |
import ray | |
from ray import tune | |
from ray.tune import CLIReporter | |
from ray.tune.schedulers import ASHAScheduler | |
from trainers.cuml.svm_trainer import SvmTrainer | |
from trainers.cuml.random_forest_trainer import RandomForestTrainer | |
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer | |
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer | |
class RayTuneOptimizer(HyperparameterOptimizer): | |
""" | |
Optimiseur d'hyperparamètres basé sur Ray Tune. | |
Implémente l'interface HyperparameterOptimizer. | |
""" | |
def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]: | |
""" | |
Recherche les meilleurs hyperparamètres pour un 'trainer' donné, | |
selon la grille 'param_grid', en utilisant Ray Tune. | |
:param trainer: Instance d'une classe implémentant BaseTrainer. | |
:param param_grid: Dictionnaire définissant l'espace de recherche | |
pour chaque hyperparamètre. | |
:return: Un dictionnaire contenant les hyperparamètres optimaux trouvés. | |
""" | |
full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config | |
config = self._create_config(param_grid) | |
scheduler = self._create_scheduler() | |
reporter = self._create_reporter() | |
analysis = tune.run( | |
self._train_model, | |
config={"hyperparameters": config, **full_config}, | |
num_samples=full_config['hyperparameters'].get('n_trials', 100), | |
scheduler=scheduler, | |
progress_reporter=reporter, | |
resources_per_trial={"cpu": 1, "gpu": 0} | |
) | |
return analysis.best_config | |
def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]: | |
""" | |
Crée la configuration pour Ray Tune à partir du param_grid. | |
:param param_grid: Dictionnaire définissant l'espace de recherche | |
pour chaque hyperparamètre. | |
:return: Dictionnaire de configuration pour Ray Tune. | |
""" | |
return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()} | |
def _define_search_space(self, param: str, vals: Any) -> Any: | |
""" | |
Définit l'espace de recherche pour un hyperparamètre donné. | |
:param param: Nom de l'hyperparamètre. | |
:param vals: Valeurs possibles ou dictionnaire définissant l'espace. | |
:return: Espace de recherche Ray Tune. | |
""" | |
if isinstance(vals, list): | |
return tune.choice(vals) | |
elif isinstance(vals, dict): | |
low = vals.get('low', 0) | |
high = vals.get('high', 10) | |
log = vals.get('log', False) | |
if 'low' in vals and 'high' in vals: | |
return tune.uniform(param, low, high) if log else tune.randint(param, low, high) | |
return tune.uniform(param, 0.0, 1.0) | |
def _create_scheduler(self) -> ASHAScheduler: | |
""" | |
Crée un scheduler ASHAScheduler pour Ray Tune. | |
:return: Instance d'ASHAScheduler. | |
""" | |
return ASHAScheduler( | |
max_t=100, | |
grace_period=10, | |
reduction_factor=2 | |
) | |
def _create_reporter(self) -> CLIReporter: | |
""" | |
Crée un reporter CLIReporter pour Ray Tune. | |
:return: Instance de CLIReporter. | |
""" | |
return CLIReporter( | |
metric_columns=["validation_loss", "training_iteration"] | |
) | |
def _train_model(self, config: Dict[str, Any]): | |
""" | |
Fonction d'entraînement pour Ray Tune. | |
:param config: Configuration des hyperparamètres. | |
""" | |
merged_config = config.copy() | |
hyperparams = merged_config.pop('hyperparameters', {}) | |
trainer_instance = self._get_trainer_instance(merged_config) | |
trainer_instance.config.hyperparameters = hyperparams | |
trainer_instance.train() | |
results = trainer_instance.evaluate() | |
tune.report(validation_loss=results.get('validation_loss', float('inf'))) | |
def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer": | |
""" | |
Obtient une instance de BaseTrainer basée sur la configuration. | |
:param config: Configuration globale incluant hyperparamètres. | |
:return: Instance de BaseTrainer. | |
""" | |
model_type = config['model']['type'].lower() | |
trainer_mapping = { | |
'svm': SvmTrainer, | |
'random_forest': RandomForestTrainer, | |
'logistic_regression': LogisticRegressionTrainer, | |
'linear_regression': LinearRegressionTrainer, | |
# Ajouter d'autres mappings ici si nécessaire | |
} | |
trainer_class = trainer_mapping.get(model_type) | |
if not trainer_class: | |
raise ValueError(f"Type de modèle non supporté : {model_type}") | |
return trainer_class( | |
config=config, | |
data_path=config['data']['path'], | |
target_column=config['data']['target_column'] | |
) | |