Spaces:
Sleeping
Sleeping
File size: 5,366 Bytes
bf5fb5f 8ffb539 bf5fb5f 8ffb539 bf5fb5f 8ffb539 bf5fb5f 8ffb539 bf5fb5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# =======================================
# Fichier: optimizers/ray_tune_optimizer.py
# =======================================
from typing import Dict, Any, TYPE_CHECKING
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer
if TYPE_CHECKING:
from base_trainer import BaseTrainer
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
class RayTuneOptimizer(HyperparameterOptimizer):
"""
Optimiseur d'hyperparamètres basé sur Ray Tune.
Implémente l'interface HyperparameterOptimizer.
"""
def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]:
"""
Recherche les meilleurs hyperparamètres pour un 'trainer' donné,
selon la grille 'param_grid', en utilisant Ray Tune.
:param trainer: Instance d'une classe implémentant BaseTrainer.
:param param_grid: Dictionnaire définissant l'espace de recherche
pour chaque hyperparamètre.
:return: Un dictionnaire contenant les hyperparamètres optimaux trouvés.
"""
full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config
config = self._create_config(param_grid)
scheduler = self._create_scheduler()
reporter = self._create_reporter()
analysis = tune.run(
self._train_model,
config={"hyperparameters": config, **full_config},
num_samples=full_config['hyperparameters'].get('n_trials', 100),
scheduler=scheduler,
progress_reporter=reporter,
resources_per_trial={"cpu": 1, "gpu": 0}
)
return analysis.best_config
def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]:
"""
Crée la configuration pour Ray Tune à partir du param_grid.
:param param_grid: Dictionnaire définissant l'espace de recherche
pour chaque hyperparamètre.
:return: Dictionnaire de configuration pour Ray Tune.
"""
return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()}
def _define_search_space(self, param: str, vals: Any) -> Any:
"""
Définit l'espace de recherche pour un hyperparamètre donné.
:param param: Nom de l'hyperparamètre.
:param vals: Valeurs possibles ou dictionnaire définissant l'espace.
:return: Espace de recherche Ray Tune.
"""
if isinstance(vals, list):
return tune.choice(vals)
elif isinstance(vals, dict):
low = vals.get('low', 0)
high = vals.get('high', 10)
log = vals.get('log', False)
if 'low' in vals and 'high' in vals:
return tune.uniform(param, low, high) if log else tune.randint(param, low, high)
return tune.uniform(param, 0.0, 1.0)
def _create_scheduler(self) -> ASHAScheduler:
"""
Crée un scheduler ASHAScheduler pour Ray Tune.
:return: Instance d'ASHAScheduler.
"""
return ASHAScheduler(
max_t=100,
grace_period=10,
reduction_factor=2
)
def _create_reporter(self) -> CLIReporter:
"""
Crée un reporter CLIReporter pour Ray Tune.
:return: Instance de CLIReporter.
"""
return CLIReporter(
metric_columns=["validation_loss", "training_iteration"]
)
def _train_model(self, config: Dict[str, Any]):
"""
Fonction d'entraînement pour Ray Tune.
:param config: Configuration des hyperparamètres.
"""
merged_config = config.copy()
hyperparams = merged_config.pop('hyperparameters', {})
trainer_instance = self._get_trainer_instance(merged_config)
trainer_instance.config.hyperparameters = hyperparams
trainer_instance.train()
results = trainer_instance.evaluate()
tune.report(validation_loss=results.get('validation_loss', float('inf')))
def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer":
"""
Obtient une instance de BaseTrainer basée sur la configuration.
:param config: Configuration globale incluant hyperparamètres.
:return: Instance de BaseTrainer.
"""
model_type = config['model']['type'].lower()
trainer_mapping = {
'svm': SvmTrainer,
'random_forest': RandomForestTrainer,
'logistic_regression': LogisticRegressionTrainer,
'linear_regression': LinearRegressionTrainer,
# Ajouter d'autres mappings ici si nécessaire
}
trainer_class = trainer_mapping.get(model_type)
if not trainer_class:
raise ValueError(f"Type de modèle non supporté : {model_type}")
return trainer_class(
config=config,
data_path=config['data']['path'],
target_column=config['data']['target_column']
)
|