fioriclass's picture
propre et fonctionnel
47d1597
raw
history blame
6.03 kB
#!/usr/bin/env python3
# =========================
# Fichier: main.py
# =========================
import os
import logging
import hydra
import mlflow
from omegaconf import DictConfig, OmegaConf
from config import Config
from trainers.cuml.svm_trainer import SvmTrainer
from trainers.cuml.random_forest_trainer import RandomForestTrainer
from trainers.cuml.logistic_regression_trainer import LogisticRegressionTrainer
from trainers.cuml.linear_regression_trainer import LinearRegressionTrainer
from trainers.huggingface.huggingface_transformer_trainer import HuggingFaceTransformerTrainer
from optimizers.optuna_optimizer import OptunaOptimizer
from optimizers.ray_tune_optimizer import RayTuneOptimizer
from mlflow_integration.mlflow_decorator import MLflowDecorator
import tempfile
import pickle
import pandas as pd
from utilities.cuml_pyfunc_wrapper import CuMLPyFuncWrapper
logger = logging.getLogger(__name__)
def get_trainer(config: Config):
"""
Crée et retourne l'instance du trainer approprié en fonction de la configuration.
"""
model_type = config.model.type.lower()
# Mapping des types de modèles vers leurs trainers
trainer_map = {
"svm": SvmTrainer,
"random_forest": RandomForestTrainer,
"logistic_regression": LogisticRegressionTrainer,
"linear_regression": LinearRegressionTrainer,
"transformer": HuggingFaceTransformerTrainer,
}
if model_type not in trainer_map:
raise ValueError(f"Type de modèle non supporté: {model_type}")
trainer_class = trainer_map[model_type]
return trainer_class(
config=config,
data_path=config.data.path,
target_column=config.data.target_column
)
def get_optimizer(config: Config):
"""
Crée et retourne l'instance d'optimizer appropriée en fonction de la configuration.
"""
optimizer_type = config.hyperparameters.optimizer.lower()
optimizer_map = {
"optuna": OptunaOptimizer,
"raytune": RayTuneOptimizer,
}
if optimizer_type not in optimizer_map:
raise ValueError(f"Type d'optimizer non supporté: {optimizer_type}")
return optimizer_map[optimizer_type]()
def log_cuml_model_to_mlflow(trainer_instance, run_id=None):
"""
Sérialise le vectorizer et le classifier dans un répertoire temporaire
puis log le tout dans MLflow en tant que modèle PyFunc.
Les artifacts sont ainsi stockés dans mlruns, liés au run en cours.
"""
logger.info("Logging du modèle CuML via mlflow.pyfunc.log_model...")
input_example = pd.DataFrame({"example_text": ["exemple"]})
# On va utiliser mlflow.pyfunc.log_model pour stocker le wrapper PyFunc + nos artifacts
with tempfile.TemporaryDirectory() as tmpdir:
vectorizer_path = os.path.join(tmpdir, "vectorizer.pkl")
classifier_path = os.path.join(tmpdir, "classifier.pkl")
# Sauvegarde sur disque
with open(vectorizer_path, "wb") as vf:
pickle.dump(trainer_instance.vectorizer, vf)
with open(classifier_path, "wb") as cf:
pickle.dump(trainer_instance.classifier, cf)
# PyFunc wrapper (placeholder, héberge la logique de load_model)
pyfunc_wrapper = CuMLPyFuncWrapper(
vectorizer=None,
classifier=None
)
# Log en PyFunc; "cuml_model" est le chemin (artifact_path) où sera stocké le modèle dans MLflow
mlflow.pyfunc.log_model(
artifact_path="cuml_model",
python_model=pyfunc_wrapper,
artifacts={
"vectorizer": vectorizer_path,
"classifier": classifier_path
},
input_example=input_example
)
logger.info("Le modèle et ses artifacts ont été enregistrés dans MLflow.")
@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
"""
Point d'entrée principal de l'application.
"""
try:
config = Config(**OmegaConf.to_container(cfg, resolve=True))
except Exception as e:
logger.error(f"Erreur lors de la validation Pydantic de la configuration: {e}")
logger.error(f"Configuration après fusion Hydra: \n{OmegaConf.to_yaml(cfg)}")
raise
logger.info(f"Configuration Pydantic finale chargée: {config}")
# Sélection du tracker MLflow
mlflow.set_tracking_uri(config.mlflow.tracking_uri)
trainer = get_trainer(config)
trainer.build_components()
def run_pipeline(trainer_instance):
"""
Exécute la séquence complète :
- Optimisation hyperparamètres (si besoin)
- Entraînement
- Évaluation
- Logging MLflow (paramètres, métriques, et modèles)
"""
logger.info("Vérification et lancement de l'optimisation des hyperparamètres si nécessaire...")
trainer_instance.optimize_if_needed()
logger.info("Lancement de l'entraînement...")
trainer_instance.train()
logger.info("Lancement de l'évaluation...")
metrics = trainer_instance.evaluate()
logger.info(f"Metrics calculés: {metrics}")
# Log des métriques
mlflow.log_metrics(metrics)
logger.info("Logging des paramètres...")
trainer_instance.log_parameters_to_mlflow()
# Log du modèle final (vectorizer+classifier) sous forme PyFunc
log_cuml_model_to_mlflow(trainer_instance)
logger.info("Pipeline MLflow complet terminé.")
# On utilise un décorateur défini pour centraliser le démarrage/arrêt du run
mlflow_decorator = MLflowDecorator(
experiment_name=config.mlflow.experiment_name,
tracking_uri=config.mlflow.tracking_uri
)
run_pipeline_with_mlflow = mlflow_decorator(run_pipeline)
logger.info("Lancement du pipeline complet avec MLflow...")
run_pipeline_with_mlflow(trainer)
logger.info("Pipeline MLflow terminé avec succès.")
if __name__ == "__main__":
main()