|
|
|
|
|
|
|
|
|
from typing import Dict, Any, TYPE_CHECKING |
|
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer |
|
import optuna |
|
|
|
if TYPE_CHECKING: |
|
from base_trainer import BaseTrainer |
|
|
|
|
|
class OptunaOptimizer(HyperparameterOptimizer):

    """
    Hyperparameter optimizer backed by the Optuna library.

    Implements the HyperparameterOptimizer interface.
    """

    def optimize(self, trainer: "BaseTrainer",
                 param_grid: Dict[str, Any]) -> Dict[str, Any]:
        """
        Search for the best hyperparameters for the given ``trainer``
        over the search space described by ``param_grid``, using Optuna.

        Each entry of ``param_grid`` maps a parameter name to either:
        - a list  -> categorical choice among the listed values;
        - a dict with 'low' and 'high' (optional 'log') -> float range;
        - any other dict -> integer range ('low' defaults to 0,
          'high' defaults to 10);
        - anything else -> float sampled uniformly in [0.0, 1.0].

        :param trainer: Instance of a class implementing BaseTrainer.
        :param param_grid: Dictionary defining the search space for each
                           hyperparameter.
        :return: Dictionary containing the best hyperparameters found.
        """

        def suggest_param(trial, param, values):
            # A plain list means a categorical choice.
            if isinstance(values, list):
                return trial.suggest_categorical(param, values)
            if isinstance(values, dict):
                # Both bounds present -> continuous range.
                if 'low' in values and 'high' in values:
                    return trial.suggest_float(
                        param, values['low'], values['high'],
                        log=values.get('log', False))
                # Partial range -> integer range. Default the missing
                # bound(s) instead of raising KeyError (the previous
                # version crashed when 'low' was absent).
                return trial.suggest_int(
                    param, values.get('low', 0), values.get('high', 10))
            # Unknown spec: fall back to a unit-interval float.
            return trial.suggest_float(param, 0.0, 1.0)

        def objective(trial):
            params = {name: suggest_param(trial, name, spec)
                      for name, spec in param_grid.items()}
            trainer.config.hyperparameters = params
            trainer.train()
            results = trainer.evaluate()

            model_type = trainer.config.model.type.lower()
            # Maximize the model's F1 score directly (equivalent to the
            # previous "minimize negated F1", but without the double
            # negative). Missing metric counts as 0.0.
            return results.get(f'{model_type}_f1', 0.0)

        # NOTE: n_trials must be read BEFORE the first trial runs, because
        # objective() overwrites trainer.config.hyperparameters with a
        # plain dict that no longer carries the n_trials attribute.
        n_trials = trainer.config.hyperparameters.n_trials
        study = optuna.create_study(direction='maximize')
        study.optimize(objective, n_trials=n_trials)

        return study.best_params
|
|