|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field |
|
from typing import Dict, Any |
|
|
|
|
|
class ModelConfig(BaseModel): |
|
""" |
|
Représente la configuration du modèle, incluant le type de modèle |
|
et les paramètres associés. |
|
""" |
|
type: str = Field( |
|
..., |
|
description= |
|
"Le type de modèle à entraîner (ex. 'svm', 'random_forest', 'logistic_regression', etc.)." |
|
) |
|
params: Dict[str, Any] = Field( |
|
default_factory=dict, |
|
description="Dictionnaire des paramètres propres au modèle choisi.") |
|
|
|
|
|
class VectorizationConfig(BaseModel): |
|
""" |
|
Représente la configuration de la vectorisation, incluant la méthode |
|
et les paramètres éventuels. |
|
""" |
|
method: str = Field( |
|
..., |
|
description="Le type de vectorisation à utiliser (ex. 'tfidf', 'bow')." |
|
) |
|
tfidf: Dict[str, Any] = Field( |
|
default_factory=dict, |
|
description="Paramètres spécifiques à une vectorisation TF-IDF.") |
|
bow: Dict[str, Any] = Field( |
|
default_factory=dict, |
|
description="Paramètres spécifiques à une vectorisation bag-of-words.") |
|
|
|
|
|
class DataConfig(BaseModel): |
|
""" |
|
Représente la configuration liée aux données, incluant |
|
le chemin vers les données et le nom de la colonne cible. |
|
""" |
|
path: str = Field(..., |
|
description="Chemin d'accès vers la source de données.") |
|
target_column: str = Field( |
|
..., description="Nom de la colonne contenant la variable cible.") |
|
|
|
|
|
class HyperparameterConfig(BaseModel): |
|
""" |
|
Représente la configuration pour l'optimisation des hyperparamètres, |
|
incluant le nom de l'optimiseur, la grille de paramètres et |
|
le nombre d'itérations d'entraînement. |
|
""" |
|
optimizer: str = Field( |
|
..., |
|
description= |
|
"Nom de l'optimiseur d'hyperparamètres (ex. 'optuna', 'raytune').") |
|
param_grid: Dict[str, Any] = Field( |
|
default_factory=dict, |
|
description= |
|
"Grille définissant l'espace de recherche pour chaque hyperparamètre.") |
|
n_trials: int = Field( |
|
default=1, |
|
description="Nombre d'essais pour la recherche d'hyperparamètres.") |
|
|
|
|
|
class MLflowConfig(BaseModel): |
|
""" |
|
Représente la configuration pour MLflow, incluant le nom de l'expérience |
|
et l'URI de tracking. |
|
""" |
|
experiment_name: str = Field( |
|
..., |
|
description="Nom de l'expérience MLflow." |
|
) |
|
tracking_uri: str = Field( |
|
..., |
|
description="URI de tracking MLflow." |
|
) |
|
|
|
|
|
class Config(BaseModel): |
|
""" |
|
Objet de configuration global combinant la section modèle, vectorisation, |
|
données, hyperparamètres et MLflow. |
|
""" |
|
model: ModelConfig |
|
vectorization: VectorizationConfig |
|
data: DataConfig |
|
hyperparameters: HyperparameterConfig |
|
mlflow: MLflowConfig |
|
|