fioriclass commited on
Commit
bed4774
·
1 Parent(s): 8c7aba9

tentaive de correction de l'erreur

Browse files
src/mlflow_integration/mlflow_decorator.py CHANGED
@@ -44,6 +44,24 @@ class MLflowDecorator:
44
 
45
  return wrapper
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def _start_run(self) -> None:
48
  """
49
  Démarre explicitement un run MLflow.
 
44
 
45
  return wrapper
46
 
47
+ def _get_params_for_trainer(self, trainer_instance) -> Dict[str, Any]:
48
+ """
49
+ Gets the relevant parameters from the trainer's config.
50
+ This replaces the singledispatch logic previously in parameter_logging.py
51
+ """
52
+ # In this specific project structure, all trainers seem to store
53
+ # the relevant hyperparameters directly in trainer_instance.config.model.params.
54
+ # If specific trainers needed different logic, we could add isinstance checks here.
55
+ # from trainers.cuml.svm_trainer import SvmTrainer # Example import if needed
56
+ # if isinstance(trainer_instance, SvmTrainer):
57
+ # return specific_logic_for_svm(trainer_instance)
58
+
59
+ # Default logic: return the model params from the config
60
+ if hasattr(trainer_instance, 'config') and hasattr(trainer_instance.config, 'model') and hasattr(trainer_instance.config.model, 'params'):
61
+ return trainer_instance.config.model.params
62
+ return {}
63
+
64
+
65
  def _start_run(self) -> None:
66
  """
67
  Démarre explicitement un run MLflow.
src/utilities/cuml_pyfunc_wrapper.py CHANGED
@@ -9,11 +9,12 @@ import cudf
9
  import cupy as cp
10
  import pickle
11
  import os
 
12
 
13
  from interfaces.vectorizer import Vectorizer
14
 
15
 
16
- class CuMLPyFuncWrapper:
17
  """
18
  Classe wrapper pour intégration de modèles cuML dans MLflow PyFunc,
19
  permettant le chargement et l'inférence.
 
9
  import cupy as cp
10
  import pickle
11
  import os
12
+ import mlflow
13
 
14
  from interfaces.vectorizer import Vectorizer
15
 
16
 
17
+ class CuMLPyFuncWrapper(mlflow.pyfunc.PythonModel):
18
  """
19
  Classe wrapper pour intégration de modèles cuML dans MLflow PyFunc,
20
  permettant le chargement et l'inférence.