fioriclass's picture
initialisation
bf5fb5f
raw
history blame
4.57 kB
#!/usr/bin/env python3
# =========================
# Fichier: main.py
# =========================
import os
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import mlflow
# Import des trainers
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer
# Import des optimizers
from optimizers.optuna_optimizer import OptunaOptimizer
from optimizers.ray_tune_optimizer import RayTuneOptimizer
# Import du décorateur MLflow
from mlflow_integration.mlflow_decorator import MLflowDecorator
# Import de la configuration
from config import Config
# Configuration du logging
logger = logging.getLogger(__name__)
def get_trainer(config: Config):
"""
Crée et retourne l'instance du trainer approprié en fonction de la configuration.
Args:
config: Objet de configuration
Returns:
Une instance concrète de BaseTrainer
"""
model_type = config.model.type.lower()
# Mapping des types de modèles vers leurs trainers
trainer_map = {
"svm": SvmTrainer,
"random_forest": RandomForestTrainer,
"logistic_regression": LogisticRegressionTrainer,
"linear_regression": LinearRegressionTrainer,
"transformer": HuggingFaceTransformerTrainer,
}
if model_type not in trainer_map:
raise ValueError(f"Type de modèle non supporté: {model_type}")
# Création de l'instance du trainer avec la configuration
trainer_class = trainer_map[model_type]
return trainer_class(
config=config,
data_path=config.data.path,
target_column=config.data.target_column
)
def get_optimizer(config: Config):
"""
Crée et retourne l'instance d'optimizer appropriée en fonction de la configuration.
Args:
config: Objet de configuration
Returns:
Une instance concrète de HyperparameterOptimizer
"""
optimizer_type = config.hyperparameters.optimizer.lower()
# Mapping des types d'optimizers
optimizer_map = {
"optuna": OptunaOptimizer,
"raytune": RayTuneOptimizer,
}
if optimizer_type not in optimizer_map:
raise ValueError(f"Type d'optimizer non supporté: {optimizer_type}")
# Création de l'instance de l'optimizer
optimizer_class = optimizer_map[optimizer_type]
return optimizer_class()
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
"""
Point d'entrée principal de l'application.
Args:
cfg: Configuration Hydra sous forme de DictConfig
"""
# Conversion de la configuration Hydra en configuration Pydantic
config_dict = OmegaConf.to_container(cfg, resolve=True)
config = Config(**config_dict)
logger.info(f"Configuration chargée: {config}")
# Création du trainer approprié
trainer = get_trainer(config)
# Construction des composants (vectorizer, classifier, etc.)
trainer.build_components()
mlflow_decorator = MLflowDecorator(
experiment_name=config.mlflow.experiment_name,
tracking_uri=config.mlflow.tracking_uri
)
# Appliquer le décorateur aux méthodes clés
train_with_mlflow = mlflow_decorator(trainer.train)
evaluate_with_mlflow = mlflow_decorator(trainer.evaluate)
log_params_with_mlflow = mlflow_decorator(trainer.log_parameters_to_mlflow)
optimize_if_needed_with_mlflow = mlflow_decorator(trainer.optimize_if_needed) # Décorer aussi l'optimisation
logger.info("Vérification et lancement de l'optimisation des hyperparamètres si nécessaire (avec MLflow)...")
optimize_if_needed_with_mlflow()
# Exécuter l'entraînement (toujours avec MLflow)
logger.info("Lancement de l'entraînement avec MLflow...")
train_with_mlflow()
# Exécuter l'évaluation (toujours avec MLflow)
logger.info("Lancement de l'évaluation avec MLflow...")
evaluate_with_mlflow()
# Logger les paramètres (toujours avec MLflow)
logger.info("Logging des paramètres avec MLflow...")
log_params_with_mlflow()
logger.info("Entraînement, évaluation et logging des paramètres terminés avec succès via MLflow.")
if __name__ == "__main__":
main()