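"""Entry point: selects a trainer and a hyperparameter optimizer from the Hydra
configuration, wraps the workflow with MLflow tracking, and runs optimization,
training, evaluation, and parameter logging."""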

import logging

import hydra
from omegaconf import DictConfig, OmegaConf

from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer

from optimizers.optuna_optimizer import OptunaOptimizer
from optimizers.ray_tune_optimizer import RayTuneOptimizer

from mlflow_integration.mlflow_decorator import MLflowDecorator

from config import Config

logger = logging.getLogger(__name__)


def get_trainer(config: Config):
    """
    Create and return the appropriate trainer instance based on the configuration.

    Args:
        config: Configuration object.

    Returns:
        A concrete BaseTrainer instance.
    """
    model_type = config.model.type.lower()

    # Registry mapping the `model.type` config value to its trainer class.
    trainer_map = {
        "svm": SvmTrainer,
        "random_forest": RandomForestTrainer,
        "logistic_regression": LogisticRegressionTrainer,
        "linear_regression": LinearRegressionTrainer,
        "transformer": HuggingFaceTransformerTrainer,
    }

    if model_type not in trainer_map:
        raise ValueError(f"Unsupported model type: {model_type}")

    trainer_class = trainer_map[model_type]
    return trainer_class(
        config=config,
        data_path=config.data.path,
        target_column=config.data.target_column,
    )
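
# Example: with model.type == "svm", get_trainer(config) returns
# SvmTrainer(config=config, data_path=config.data.path, target_column=config.data.target_column).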


def get_optimizer(config: Config):
    """
    Create and return the appropriate hyperparameter optimizer instance based on
    the configuration.

    Args:
        config: Configuration object.

    Returns:
        A concrete HyperparameterOptimizer instance.
    """
    optimizer_type = config.hyperparameters.optimizer.lower()

    # Registry mapping the `hyperparameters.optimizer` config value to its class.
    optimizer_map = {
        "optuna": OptunaOptimizer,
        "raytune": RayTuneOptimizer,
    }

    if optimizer_type not in optimizer_map:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    optimizer_class = optimizer_map[optimizer_type]
    return optimizer_class()
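

# Expected configuration shape (illustrative sketch — the actual schema lives in
# config.Config and conf/config.yaml; only the keys read in this file are shown):
#
#   model:
#     type: svm              # svm | random_forest | logistic_regression | linear_regression | transformer
#   data:
#     path: ???              # dataset path
#     target_column: ???     # label column name
#   hyperparameters:
#     optimizer: optuna      # optuna | raytune
#   mlflow:
#     experiment_name: ???
#     tracking_uri: ???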


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """
    Main entry point of the application.

    Args:
        cfg: Hydra configuration as a DictConfig.
    """
    # Materialize the Hydra config as a plain dict (resolving interpolations)
    # and build the project's typed Config object from it.
    config_dict = OmegaConf.to_container(cfg, resolve=True)
    config = Config(**config_dict)

    logger.info(f"Configuration loaded: {config}")

    trainer = get_trainer(config)
    trainer.build_components()
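
    # Build the MLflow decorator and wrap each workflow step with it, so each
    # call runs with MLflow tracking (run creation and logging are handled
    # inside MLflowDecorator).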
    mlflow_decorator = MLflowDecorator(
        experiment_name=config.mlflow.experiment_name,
        tracking_uri=config.mlflow.tracking_uri,
    )

    train_with_mlflow = mlflow_decorator(trainer.train)
    evaluate_with_mlflow = mlflow_decorator(trainer.evaluate)
    log_params_with_mlflow = mlflow_decorator(trainer.log_parameters_to_mlflow)
    optimize_if_needed_with_mlflow = mlflow_decorator(trainer.optimize_if_needed)

    logger.info("Checking for and running hyperparameter optimization if needed (with MLflow)...")
    optimize_if_needed_with_mlflow()

    logger.info("Starting training with MLflow...")
    train_with_mlflow()

    logger.info("Starting evaluation with MLflow...")
    evaluate_with_mlflow()

    logger.info("Logging parameters to MLflow...")
    log_params_with_mlflow()

    logger.info("Training, evaluation, and parameter logging completed successfully via MLflow.")


if __name__ == "__main__":
    main()
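
# Example invocation (hypothetical script name; Hydra allows overriding any
# config key from the command line):
#   python main.py model.type=random_forest hyperparameters.optimizer=optuna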