fioriclass's picture
correction bug config
43d4438
# ======================
# Fichier: config.py
# ======================
from pydantic import BaseModel, Field
from typing import Dict, Any
class ModelConfig(BaseModel):
"""
Représente la configuration du modèle, incluant le type de modèle
et les paramètres associés.
"""
type: str = Field(
...,
description=
"Le type de modèle à entraîner (ex. 'svm', 'random_forest', 'logistic_regression', etc.)."
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Dictionnaire des paramètres propres au modèle choisi.")
class VectorizationConfig(BaseModel):
"""
Représente la configuration de la vectorisation, incluant la méthode
et les paramètres éventuels.
"""
method: str = Field(
...,
description="Le type de vectorisation à utiliser (ex. 'tfidf', 'bow')."
)
tfidf: Dict[str, Any] = Field(
default_factory=dict,
description="Paramètres spécifiques à une vectorisation TF-IDF.")
bow: Dict[str, Any] = Field(
default_factory=dict,
description="Paramètres spécifiques à une vectorisation bag-of-words.")
class DataConfig(BaseModel):
"""
Représente la configuration liée aux données, incluant
le chemin vers les données et le nom de la colonne cible.
"""
path: str = Field(...,
description="Chemin d'accès vers la source de données.")
target_column: str = Field(
..., description="Nom de la colonne contenant la variable cible.")
class HyperparameterConfig(BaseModel):
"""
Représente la configuration pour l'optimisation des hyperparamètres,
incluant le nom de l'optimiseur, la grille de paramètres et
le nombre d'itérations d'entraînement.
"""
optimizer: str = Field(
...,
description=
"Nom de l'optimiseur d'hyperparamètres (ex. 'optuna', 'raytune').")
param_grid: Dict[str, Any] = Field(
default_factory=dict,
description=
"Grille définissant l'espace de recherche pour chaque hyperparamètre.")
n_trials: int = Field(
default=1,
description="Nombre d'essais pour la recherche d'hyperparamètres.")
class MLflowConfig(BaseModel):
"""
Représente la configuration pour MLflow, incluant le nom de l'expérience
et l'URI de tracking.
"""
experiment_name: str = Field(
...,
description="Nom de l'expérience MLflow."
)
tracking_uri: str = Field(
...,
description="URI de tracking MLflow."
)
class Config(BaseModel):
"""
Objet de configuration global combinant la section modèle, vectorisation,
données, hyperparamètres et MLflow.
"""
model: ModelConfig
vectorization: VectorizationConfig
data: DataConfig
hyperparameters: HyperparameterConfig
mlflow: MLflowConfig