emotion_classifier / src /optimizers /optuna_optimizer.py
fioriclass's picture
on optimise le f1 score et correction erreur config model
513cd3c
# =====================================
# Fichier: optimizers/optuna_optimizer.py
# =====================================
from typing import Dict, Any, TYPE_CHECKING
from interfaces.hyperparameter_optimizer import HyperparameterOptimizer
import optuna
if TYPE_CHECKING:
from base_trainer import BaseTrainer
class OptunaOptimizer(HyperparameterOptimizer):
"""
Optimiseur d'hyperparamètres basé sur la librairie Optuna.
Implémente l'interface HyperparameterOptimizer.
"""
def optimize(self, trainer: "BaseTrainer",
param_grid: Dict[str, Any]) -> Dict[str, Any]:
"""
Recherche les meilleurs hyperparamètres pour un 'trainer' donné,
selon la grille 'param_grid', en utilisant Optuna.
:param trainer: Instance d'une classe implémentant BaseTrainer.
:param param_grid: Dictionnaire définissant l'espace de recherche
pour chaque hyperparamètre.
:return: Un dictionnaire contenant les hyperparamètres optimaux trouvés.
"""
def suggest_param(trial, param, values):
return (
trial.suggest_categorical(param, values) if isinstance(values, list) else
trial.suggest_float(param, values['low'], values['high'], log=values.get('log', False)) if isinstance(values, dict) and 'low' in values and 'high' in values else
trial.suggest_int(param, values['low'], values.get('high', 10)) if isinstance(values, dict) else
trial.suggest_float(param, 0.0, 1.0)
)
def objective(trial):
params = {param: suggest_param(trial, param, vals) for param, vals in param_grid.items()}
trainer.config.hyperparameters = params
trainer.train()
results = trainer.evaluate()
# Utiliser f1 comme métrique à maximiser (en inversant le signe car Optuna minimise)
model_type = trainer.config.model.type.lower()
return -results.get(f'{model_type}_f1', 0.0) # Inverser le signe pour maximiser
study = optuna.create_study(direction='minimize') # Minimiser -f1 équivaut à maximiser f1
n_trials = trainer.config.hyperparameters.n_trials
study.optimize(objective, n_trials=n_trials)
return study.best_params