from typing import Any, Dict, TYPE_CHECKING

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

from interfaces.hyperparameter_optimizer import HyperparameterOptimizer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.svm_trainer import SvmTrainer

if TYPE_CHECKING:
    from base_trainer import BaseTrainer
|
|
class RayTuneOptimizer(HyperparameterOptimizer):
    """
    Ray Tune-based hyperparameter optimizer.

    Implements the HyperparameterOptimizer interface.
    """
|
    def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]:
        """
        Search for the best hyperparameters for the given trainer over the
        search space described by param_grid, using Ray Tune.

        :param trainer: Instance of a class implementing BaseTrainer.
        :param param_grid: Dictionary defining the search space for each
                           hyperparameter.
        :return: A dictionary containing the best hyperparameters found.
        """
        full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config
        config = self._create_config(param_grid)
        scheduler = self._create_scheduler()
        reporter = self._create_reporter()

        analysis = tune.run(
            self._train_model,
            # Spread full_config first so the tuned search space is not
            # clobbered by the trainer's existing 'hyperparameters' entry.
            config={**full_config, "hyperparameters": config},
            num_samples=full_config.get('hyperparameters', {}).get('n_trials', 100),
            # metric/mode are required for best_config below and drive
            # ASHA's early-stopping decisions.
            metric="validation_loss",
            mode="min",
            scheduler=scheduler,
            progress_reporter=reporter,
            # cuML trainers normally need a GPU; raise this if trials
            # fail to acquire one.
            resources_per_trial={"cpu": 1, "gpu": 0}
        )

        # best_config holds the full trial config; return only the resolved
        # hyperparameters, as the docstring promises.
        return analysis.best_config["hyperparameters"]
|
    def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the Ray Tune configuration from param_grid.

        :param param_grid: Dictionary defining the search space for each
                           hyperparameter.
        :return: Configuration dictionary for Ray Tune.
        """
        return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()}
|
    def _define_search_space(self, param: str, vals: Any) -> Any:
        """
        Define the search space for a given hyperparameter.

        :param param: Name of the hyperparameter.
        :param vals: List of candidate values, or a dictionary describing
                     a numeric range.
        :return: A Ray Tune search-space object.
        """
        if isinstance(vals, list):
            return tune.choice(vals)
        elif isinstance(vals, dict):
            low = vals.get('low', 0)
            high = vals.get('high', 10)
            log = vals.get('log', False)
            if 'low' in vals and 'high' in vals:
                # Ray Tune samplers take (lower, upper) bounds, not the
                # parameter name.
                if log:
                    return tune.loguniform(low, high)
                if isinstance(low, int) and isinstance(high, int):
                    return tune.randint(low, high)
                return tune.uniform(low, high)
            return tune.uniform(0.0, 1.0)
        # Anything else is treated as a fixed value.
        return vals
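
    # Illustrative mapping of the accepted param_grid forms (the values
    # below are examples only, not project defaults):
    #   "kernel": ["linear", "rbf"]                    -> tune.choice([...])
    #   "C": {"low": 1e-3, "high": 10.0, "log": True}  -> tune.loguniform(1e-3, 10.0)
    #   "n_estimators": {"low": 50, "high": 500}       -> tune.randint(50, 500)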
|
    def _create_scheduler(self) -> ASHAScheduler:
        """
        Create an ASHAScheduler for Ray Tune.

        :return: ASHAScheduler instance.
        """
        return ASHAScheduler(
            max_t=100,           # maximum training iterations per trial
            grace_period=10,     # minimum iterations before a trial may be stopped
            reduction_factor=2   # aggressiveness of successive halving
        )
|
    def _create_reporter(self) -> CLIReporter:
        """
        Create a CLIReporter for Ray Tune.

        :return: CLIReporter instance.
        """
        return CLIReporter(
            metric_columns=["validation_loss", "training_iteration"]
        )
|
    def _train_model(self, config: Dict[str, Any]):
        """
        Per-trial training function executed by Ray Tune.

        :param config: Full trial configuration, with the sampled
                       hyperparameters under the 'hyperparameters' key.
        """
        merged_config = config.copy()
        hyperparams = merged_config.pop('hyperparameters', {})
        trainer_instance = self._get_trainer_instance(merged_config)
        # The trainer's config may be a plain dict or an object with
        # attribute access; handle both.
        if isinstance(trainer_instance.config, dict):
            trainer_instance.config['hyperparameters'] = hyperparams
        else:
            trainer_instance.config.hyperparameters = hyperparams
        trainer_instance.train()
        results = trainer_instance.evaluate()
        tune.report(validation_loss=results.get('validation_loss', float('inf')))
|
    def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer":
        """
        Build a BaseTrainer instance from the configuration.

        :param config: Global configuration, including the hyperparameters.
        :return: BaseTrainer instance.
        """
        model_type = config['model']['type'].lower()
        trainer_mapping = {
            'svm': SvmTrainer,
            'random_forest': RandomForestTrainer,
            'logistic_regression': LogisticRegressionTrainer,
            'linear_regression': LinearRegressionTrainer,
        }

        trainer_class = trainer_mapping.get(model_type)
        if not trainer_class:
            raise ValueError(f"Unsupported model type: {model_type}")

        return trainer_class(
            config=config,
            data_path=config['data']['path'],
            target_column=config['data']['target_column']
        )
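

# Minimal usage sketch (illustrative only): assumes the trainer's config
# carries the 'model', 'data' and 'hyperparameters' sections this optimizer
# reads; the path and column names below are placeholders.
if __name__ == "__main__":
    config = {
        "model": {"type": "svm"},
        "data": {"path": "data/train.csv", "target_column": "label"},
        "hyperparameters": {"n_trials": 20},
    }
    trainer = SvmTrainer(
        config=config,
        data_path=config["data"]["path"],
        target_column=config["data"]["target_column"],
    )
    param_grid = {
        "C": {"low": 1e-3, "high": 10.0, "log": True},
        "kernel": ["linear", "rbf"],
    }
    best = RayTuneOptimizer().optimize(trainer, param_grid)
    print(f"Best hyperparameters: {best}")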