emotion_classifier / src /base_trainer.py
fioriclass's picture
maj pour sauvegarder le model
8c7aba9
# =========================
# Fichier: base_trainer.py
# =========================
from abc import ABC, abstractmethod
from typing import Union, Optional
import cupy as cp
from scipy.sparse import csr_matrix
import mlflow
from config import Config
from interfaces.metrics_calculator import MetricsCalculator
class BaseTrainer(ABC):
"""
Classe de base abstraite représentant un entraîneur (trainer) générique,
tel que défini dans le diagramme UML.
Attributs:
config (Config): Configuration globale du système (modèle, data, etc.).
classifier (object): Référence vers le classifieur ou le modèle entraîné.
metrics_calculator (MetricsCalculator): Outil de calcul et de logging des métriques.
data_path (str): Chemin vers les données.
target_column (str): Nom de la colonne cible dans les données.
"""
def __init__(self, config: Config, data_path: str,
target_column: str) -> None:
"""
Initialise un trainer générique avec la configuration et les informations de chemin de données.
:param config: Objet de configuration global.
:param data_path: Chemin vers le fichier de données.
:param target_column: Nom de la colonne cible pour l'entraînement/prédiction.
"""
self.config: Config = config
self.data_path: str = data_path
self.target_column: str = target_column
self.classifier: object = None
self.metrics_calculator: MetricsCalculator = None
@abstractmethod
def build_components(self) -> None:
"""
Méthode abstraite. Instancie les composants nécessaires
(e.g. le classifieur, éventuellement le vectorizer) selon la config.
"""
pass
@abstractmethod
def train(self) -> None:
"""
Méthode abstraite. Lance la procédure d'entraînement.
"""
pass
@abstractmethod
def evaluate(self) -> dict:
"""
Méthode abstraite. Évalue le modèle entraîné, par exemple
sur un jeu de validation ou de test, et calcule les métriques.
:return: Dictionnaire contenant les métriques calculées.
"""
pass
def optimize_if_needed(self) -> None:
"""
Vérifie la configuration pour déterminer si l'optimisation des hyperparamètres
est nécessaire. Si oui, instancie l'optimiseur approprié et lance
le processus d'optimisation. Met ensuite à jour la configuration du modèle
avec les meilleurs paramètres trouvés et reconstruit les composants.
"""
import logging
logger = logging.getLogger(__name__)
# Vérifier si l'optimisation est configurée
if (self.config.hyperparameters.optimizer and
self.config.hyperparameters.param_grid and
self.config.hyperparameters.n_trials > 0):
logger.info("Démarrage de l'optimisation des hyperparamètres")
# Importation et instanciation de l'optimiseur
optimizer_type = self.config.hyperparameters.optimizer.lower()
if optimizer_type == "optuna":
from optimizers.optuna_optimizer import OptunaOptimizer
optimizer = OptunaOptimizer()
elif optimizer_type == "raytune":
from optimizers.ray_tune_optimizer import RayTuneOptimizer
optimizer = RayTuneOptimizer()
else:
raise ValueError(f"Type d'optimizer non supporté: {optimizer_type}")
# Lancement de l'optimisation
best_params = optimizer.optimize(
trainer=self, # Passe l'instance actuelle du trainer
param_grid=self.config.hyperparameters.param_grid
)
logger.info(f"Meilleurs hyperparamètres trouvés: {best_params}")
# Mise à jour de la configuration du modèle avec les meilleurs paramètres
self.config.model.params.update(best_params)
# Reconstruire les composants avec les nouveaux paramètres
logger.info("Reconstruction des composants avec les hyperparamètres optimisés.")
self.build_components()
else:
logger.info("Aucune optimisation des hyperparamètres configurée.")
def log_parameters_to_mlflow(self) -> None:
"""
Appelle une fonction singledispatch (get_relevant_params_for_logging(trainer)) pour récupérer
les paramètres pertinents et les logguer, par exemple via MLflow.
Implementé ici en tant que méthode non-abstraite, mais la logique de logging
devrait être assurée dans l'environnement MLflow approprié.
"""
# Logue les paramètres du config.model
if self.config.model.params:
mlflow.log_params(self.config.model.params)
# Les paramètres pertinents du modèle (ceux utilisés pour l'initialiser, ex: C, kernel pour SVM)
# sont déjà loggués via self.config.model.params ci-dessus, qui est correctement
# peuplé grâce à l'interpolation Hydra dans config.yaml.
# Éviter de logger self.classifier.get_params() car cela est redondant et
# inclut des objets internes non sérialisables comme le handle RAFT/GPU,
# causant l'apparition de "<pylibraft.common.handle.Handle object ...>" dans les logs MLflow.
def _prepare_input_for_fit(
self, X: Union[cp.ndarray,
csr_matrix]) -> Union[cp.ndarray, csr_matrix]:
"""
Méthode utilitaire pour préparer les données d'entraînement avant
l'ajustement du modèle.
:param X: Matrice (cupy.ndarray ou scipy.sparse.csr_matrix) représentant les données.
:return: Matrice transformée ou identique, prête pour l'entraînement.
"""
return X
def _prepare_input_for_predict(
self, X: Union[cp.ndarray,
csr_matrix]) -> Union[cp.ndarray, csr_matrix]:
"""
Méthode utilitaire pour préparer les données de prédiction avant
l'appel à la méthode `predict` du modèle.
:param X: Matrice (cupy.ndarray ou scipy.sparse.csr_matrix) représentant les données.
:return: Matrice transformée ou identique, prête pour la prédiction.
"""
return X
def _get_binary_predictions(self, X: cp.ndarray) -> cp.ndarray:
"""
Retourne un vecteur de prédictions binaires (0/1).
:param X: Matrice de données de dimension (n_samples, n_features),
déjà sous forme cupy.ndarray.
:return: Vecteur de prédictions binaires (cupy.ndarray).
"""
# Ici, la logique de conversion en 0/1 n'est pas spécifiée dans l'UML,
# donc on la laisse minimale (raise NotImplementedError si nécessaire).
raise NotImplementedError(
"La méthode '_get_binary_predictions' doit être implémentée dans une sous-classe."
)
def _get_positive_probabilities(self,
X: cp.ndarray) -> Optional[cp.ndarray]:
"""
Retourne la probabilité d'appartenir à la classe positive pour chaque échantillon,
si le modèle le permet. Sinon, retourne None.
:param X: Matrice de données en cupy.ndarray.
:return: Vecteur de probabilités (cupy.ndarray) ou None si non applicable.
"""
return None
def _get_label_dtype(self) -> cp.dtype:
"""
Retourne le type cupy.dtype approprié pour les labels.
:return: Par exemple, cp.int32.
"""
return cp.int32