# ========================= # Fichier: base_trainer.py # ========================= from abc import ABC, abstractmethod from typing import Union, Optional import cupy as cp from scipy.sparse import csr_matrix import mlflow from config import Config from interfaces.metrics_calculator import MetricsCalculator class BaseTrainer(ABC): """ Classe de base abstraite représentant un entraîneur (trainer) générique, tel que défini dans le diagramme UML. Attributs: config (Config): Configuration globale du système (modèle, data, etc.). classifier (object): Référence vers le classifieur ou le modèle entraîné. metrics_calculator (MetricsCalculator): Outil de calcul et de logging des métriques. data_path (str): Chemin vers les données. target_column (str): Nom de la colonne cible dans les données. """ def __init__(self, config: Config, data_path: str, target_column: str) -> None: """ Initialise un trainer générique avec la configuration et les informations de chemin de données. :param config: Objet de configuration global. :param data_path: Chemin vers le fichier de données. :param target_column: Nom de la colonne cible pour l'entraînement/prédiction. """ self.config: Config = config self.data_path: str = data_path self.target_column: str = target_column self.classifier: object = None self.metrics_calculator: MetricsCalculator = None @abstractmethod def build_components(self) -> None: """ Méthode abstraite. Instancie les composants nécessaires (e.g. le classifieur, éventuellement le vectorizer) selon la config. """ pass @abstractmethod def train(self) -> None: """ Méthode abstraite. Lance la procédure d'entraînement. """ pass @abstractmethod def evaluate(self) -> dict: """ Méthode abstraite. Évalue le modèle entraîné, par exemple sur un jeu de validation ou de test, et calcule les métriques. :return: Dictionnaire contenant les métriques calculées. """ pass def optimize_if_needed(self) -> None: """ Vérifie la configuration pour déterminer si l'optimisation des hyperparamètres est nécessaire. Si oui, instancie l'optimiseur approprié et lance le processus d'optimisation. Met ensuite à jour la configuration du modèle avec les meilleurs paramètres trouvés et reconstruit les composants. """ import logging logger = logging.getLogger(__name__) # Vérifier si l'optimisation est configurée if (self.config.hyperparameters.optimizer and self.config.hyperparameters.param_grid and self.config.hyperparameters.n_trials > 0): logger.info("Démarrage de l'optimisation des hyperparamètres") # Importation et instanciation de l'optimiseur optimizer_type = self.config.hyperparameters.optimizer.lower() if optimizer_type == "optuna": from optimizers.optuna_optimizer import OptunaOptimizer optimizer = OptunaOptimizer() elif optimizer_type == "raytune": from optimizers.ray_tune_optimizer import RayTuneOptimizer optimizer = RayTuneOptimizer() else: raise ValueError(f"Type d'optimizer non supporté: {optimizer_type}") # Lancement de l'optimisation best_params = optimizer.optimize( trainer=self, # Passe l'instance actuelle du trainer param_grid=self.config.hyperparameters.param_grid ) logger.info(f"Meilleurs hyperparamètres trouvés: {best_params}") # Mise à jour de la configuration du modèle avec les meilleurs paramètres self.config.model.params.update(best_params) # Reconstruire les composants avec les nouveaux paramètres logger.info("Reconstruction des composants avec les hyperparamètres optimisés.") self.build_components() else: logger.info("Aucune optimisation des hyperparamètres configurée.") def log_parameters_to_mlflow(self) -> None: """ Appelle une fonction singledispatch (get_relevant_params_for_logging(trainer)) pour récupérer les paramètres pertinents et les logguer, par exemple via MLflow. Implementé ici en tant que méthode non-abstraite, mais la logique de logging devrait être assurée dans l'environnement MLflow approprié. """ # Logue les paramètres du config.model if self.config.model.params: mlflow.log_params(self.config.model.params) # Les paramètres pertinents du modèle (ceux utilisés pour l'initialiser, ex: C, kernel pour SVM) # sont déjà loggués via self.config.model.params ci-dessus, qui est correctement # peuplé grâce à l'interpolation Hydra dans config.yaml. # Éviter de logger self.classifier.get_params() car cela est redondant et # inclut des objets internes non sérialisables comme le handle RAFT/GPU, # causant l'apparition de "" dans les logs MLflow. def _prepare_input_for_fit( self, X: Union[cp.ndarray, csr_matrix]) -> Union[cp.ndarray, csr_matrix]: """ Méthode utilitaire pour préparer les données d'entraînement avant l'ajustement du modèle. :param X: Matrice (cupy.ndarray ou scipy.sparse.csr_matrix) représentant les données. :return: Matrice transformée ou identique, prête pour l'entraînement. """ return X def _prepare_input_for_predict( self, X: Union[cp.ndarray, csr_matrix]) -> Union[cp.ndarray, csr_matrix]: """ Méthode utilitaire pour préparer les données de prédiction avant l'appel à la méthode `predict` du modèle. :param X: Matrice (cupy.ndarray ou scipy.sparse.csr_matrix) représentant les données. :return: Matrice transformée ou identique, prête pour la prédiction. """ return X def _get_binary_predictions(self, X: cp.ndarray) -> cp.ndarray: """ Retourne un vecteur de prédictions binaires (0/1). :param X: Matrice de données de dimension (n_samples, n_features), déjà sous forme cupy.ndarray. :return: Vecteur de prédictions binaires (cupy.ndarray). """ # Ici, la logique de conversion en 0/1 n'est pas spécifiée dans l'UML, # donc on la laisse minimale (raise NotImplementedError si nécessaire). raise NotImplementedError( "La méthode '_get_binary_predictions' doit être implémentée dans une sous-classe." ) def _get_positive_probabilities(self, X: cp.ndarray) -> Optional[cp.ndarray]: """ Retourne la probabilité d'appartenir à la classe positive pour chaque échantillon, si le modèle le permet. Sinon, retourne None. :param X: Matrice de données en cupy.ndarray. :return: Vecteur de probabilités (cupy.ndarray) ou None si non applicable. """ return None def _get_label_dtype(self) -> cp.dtype: """ Retourne le type cupy.dtype approprié pour les labels. :return: Par exemple, cp.int32. """ return cp.int32