emotion_classifier / src /optimizers /ray_tune_optimizer.py
fioriclass's picture
correction import et autre
8ffb539
# =======================================
# Fichier: optimizers/ray_tune_optimizer.py
# =======================================
from typing import Dict, Any, TYPE_CHECKING
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer
if TYPE_CHECKING:
from base_trainer import BaseTrainer
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
class RayTuneOptimizer(HyperparameterOptimizer):
"""
Optimiseur d'hyperparamètres basé sur Ray Tune.
Implémente l'interface HyperparameterOptimizer.
"""
def optimize(self, trainer: "BaseTrainer", param_grid: Dict[str, Any]) -> Dict[str, Any]:
"""
Recherche les meilleurs hyperparamètres pour un 'trainer' donné,
selon la grille 'param_grid', en utilisant Ray Tune.
:param trainer: Instance d'une classe implémentant BaseTrainer.
:param param_grid: Dictionnaire définissant l'espace de recherche
pour chaque hyperparamètre.
:return: Un dictionnaire contenant les hyperparamètres optimaux trouvés.
"""
full_config = trainer.config.dict() if hasattr(trainer.config, 'dict') else trainer.config
config = self._create_config(param_grid)
scheduler = self._create_scheduler()
reporter = self._create_reporter()
analysis = tune.run(
self._train_model,
config={"hyperparameters": config, **full_config},
num_samples=full_config['hyperparameters'].get('n_trials', 100),
scheduler=scheduler,
progress_reporter=reporter,
resources_per_trial={"cpu": 1, "gpu": 0}
)
return analysis.best_config
def _create_config(self, param_grid: Dict[str, Any]) -> Dict[str, Any]:
"""
Crée la configuration pour Ray Tune à partir du param_grid.
:param param_grid: Dictionnaire définissant l'espace de recherche
pour chaque hyperparamètre.
:return: Dictionnaire de configuration pour Ray Tune.
"""
return {param: self._define_search_space(param, vals) for param, vals in param_grid.items()}
def _define_search_space(self, param: str, vals: Any) -> Any:
"""
Définit l'espace de recherche pour un hyperparamètre donné.
:param param: Nom de l'hyperparamètre.
:param vals: Valeurs possibles ou dictionnaire définissant l'espace.
:return: Espace de recherche Ray Tune.
"""
if isinstance(vals, list):
return tune.choice(vals)
elif isinstance(vals, dict):
low = vals.get('low', 0)
high = vals.get('high', 10)
log = vals.get('log', False)
if 'low' in vals and 'high' in vals:
return tune.uniform(param, low, high) if log else tune.randint(param, low, high)
return tune.uniform(param, 0.0, 1.0)
def _create_scheduler(self) -> ASHAScheduler:
"""
Crée un scheduler ASHAScheduler pour Ray Tune.
:return: Instance d'ASHAScheduler.
"""
return ASHAScheduler(
max_t=100,
grace_period=10,
reduction_factor=2
)
def _create_reporter(self) -> CLIReporter:
"""
Crée un reporter CLIReporter pour Ray Tune.
:return: Instance de CLIReporter.
"""
return CLIReporter(
metric_columns=["validation_loss", "training_iteration"]
)
def _train_model(self, config: Dict[str, Any]):
"""
Fonction d'entraînement pour Ray Tune.
:param config: Configuration des hyperparamètres.
"""
merged_config = config.copy()
hyperparams = merged_config.pop('hyperparameters', {})
trainer_instance = self._get_trainer_instance(merged_config)
trainer_instance.config.hyperparameters = hyperparams
trainer_instance.train()
results = trainer_instance.evaluate()
tune.report(validation_loss=results.get('validation_loss', float('inf')))
def _get_trainer_instance(self, config: Dict[str, Any]) -> "BaseTrainer":
"""
Obtient une instance de BaseTrainer basée sur la configuration.
:param config: Configuration globale incluant hyperparamètres.
:return: Instance de BaseTrainer.
"""
model_type = config['model']['type'].lower()
trainer_mapping = {
'svm': SvmTrainer,
'random_forest': RandomForestTrainer,
'logistic_regression': LogisticRegressionTrainer,
'linear_regression': LinearRegressionTrainer,
# Ajouter d'autres mappings ici si nécessaire
}
trainer_class = trainer_mapping.get(model_type)
if not trainer_class:
raise ValueError(f"Type de modèle non supporté : {model_type}")
return trainer_class(
config=config,
data_path=config['data']['path'],
target_column=config['data']['target_column']
)