File size: 2,915 Bytes
bf5fb5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43d4438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf5fb5f
 
 
43d4438
bf5fb5f
 
 
 
 
43d4438
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# ======================
# Fichier: config.py
# ======================

from pydantic import BaseModel, Field
from typing import Dict, Any


class ModelConfig(BaseModel):
    """
    Représente la configuration du modèle, incluant le type de modèle
    et les paramètres associés.
    """
    type: str = Field(
        ...,
        description=
        "Le type de modèle à entraîner (ex. 'svm', 'random_forest', 'logistic_regression', etc.)."
    )
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Dictionnaire des paramètres propres au modèle choisi.")


class VectorizationConfig(BaseModel):
    """
    Représente la configuration de la vectorisation, incluant la méthode
    et les paramètres éventuels.
    """
    method: str = Field(
        ...,
        description="Le type de vectorisation à utiliser (ex. 'tfidf', 'bow')."
    )
    tfidf: Dict[str, Any] = Field(
        default_factory=dict,
        description="Paramètres spécifiques à une vectorisation TF-IDF.")
    bow: Dict[str, Any] = Field(
        default_factory=dict,
        description="Paramètres spécifiques à une vectorisation bag-of-words.")


class DataConfig(BaseModel):
    """
    Représente la configuration liée aux données, incluant
    le chemin vers les données et le nom de la colonne cible.
    """
    path: str = Field(...,
                      description="Chemin d'accès vers la source de données.")
    target_column: str = Field(
        ..., description="Nom de la colonne contenant la variable cible.")


class HyperparameterConfig(BaseModel):
    """
    Représente la configuration pour l'optimisation des hyperparamètres,
    incluant le nom de l'optimiseur, la grille de paramètres et
    le nombre d'itérations d'entraînement.
    """
    optimizer: str = Field(
        ...,
        description=
        "Nom de l'optimiseur d'hyperparamètres (ex. 'optuna', 'raytune').")
    param_grid: Dict[str, Any] = Field(
        default_factory=dict,
        description=
        "Grille définissant l'espace de recherche pour chaque hyperparamètre.")
    n_trials: int = Field(
        default=1,
        description="Nombre d'essais pour la recherche d'hyperparamètres.")


class MLflowConfig(BaseModel):
    """
    Représente la configuration pour MLflow, incluant le nom de l'expérience
    et l'URI de tracking.
    """
    experiment_name: str = Field(
        ...,
        description="Nom de l'expérience MLflow."
    )
    tracking_uri: str = Field(
        ...,
        description="URI de tracking MLflow."
    )


class Config(BaseModel):
    """
    Objet de configuration global combinant la section modèle, vectorisation,
    données, hyperparamètres et MLflow.
    """
    model: ModelConfig
    vectorization: VectorizationConfig
    data: DataConfig
    hyperparameters: HyperparameterConfig
    mlflow: MLflowConfig