Spaces:
Sleeping
Sleeping
# ====================== | |
# Fichier: config.py | |
# ====================== | |
from pydantic import BaseModel, Field | |
from typing import Dict, Any | |
class ModelConfig(BaseModel): | |
""" | |
Représente la configuration du modèle, incluant le type de modèle | |
et les paramètres associés. | |
""" | |
type: str = Field( | |
..., | |
description= | |
"Le type de modèle à entraîner (ex. 'svm', 'random_forest', 'logistic_regression', etc.)." | |
) | |
params: Dict[str, Any] = Field( | |
default_factory=dict, | |
description="Dictionnaire des paramètres propres au modèle choisi.") | |
class VectorizationConfig(BaseModel): | |
""" | |
Représente la configuration de la vectorisation, incluant la méthode | |
et les paramètres éventuels. | |
""" | |
method: str = Field( | |
..., | |
description="Le type de vectorisation à utiliser (ex. 'tfidf', 'bow')." | |
) | |
tfidf: Dict[str, Any] = Field( | |
default_factory=dict, | |
description="Paramètres spécifiques à une vectorisation TF-IDF.") | |
bow: Dict[str, Any] = Field( | |
default_factory=dict, | |
description="Paramètres spécifiques à une vectorisation bag-of-words.") | |
class DataConfig(BaseModel): | |
""" | |
Représente la configuration liée aux données, incluant | |
le chemin vers les données et le nom de la colonne cible. | |
""" | |
path: str = Field(..., | |
description="Chemin d'accès vers la source de données.") | |
target_column: str = Field( | |
..., description="Nom de la colonne contenant la variable cible.") | |
class HyperparameterConfig(BaseModel): | |
""" | |
Représente la configuration pour l'optimisation des hyperparamètres, | |
incluant le nom de l'optimiseur, la grille de paramètres et | |
le nombre d'itérations d'entraînement. | |
""" | |
optimizer: str = Field( | |
..., | |
description= | |
"Nom de l'optimiseur d'hyperparamètres (ex. 'optuna', 'raytune').") | |
param_grid: Dict[str, Any] = Field( | |
default_factory=dict, | |
description= | |
"Grille définissant l'espace de recherche pour chaque hyperparamètre.") | |
n_trials: int = Field( | |
default=1, | |
description="Nombre d'essais pour la recherche d'hyperparamètres.") | |
class MLflowConfig(BaseModel): | |
""" | |
Représente la configuration pour MLflow, incluant le nom de l'expérience | |
et l'URI de tracking. | |
""" | |
experiment_name: str = Field( | |
..., | |
description="Nom de l'expérience MLflow." | |
) | |
tracking_uri: str = Field( | |
..., | |
description="URI de tracking MLflow." | |
) | |
class Config(BaseModel): | |
""" | |
Objet de configuration global combinant la section modèle, vectorisation, | |
données, hyperparamètres et MLflow. | |
""" | |
model: ModelConfig | |
vectorization: VectorizationConfig | |
data: DataConfig | |
hyperparameters: HyperparameterConfig | |
mlflow: MLflowConfig | |