gauravprakashh commited on
Commit
eb5dd72
·
verified ·
1 Parent(s): e3c9245

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +52 -0
utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils"""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ from loguru import logger
9
+
10
+
11
+ def download_model(
12
+ model_name: str,
13
+ model_stage: Literal["staging", "production"],
14
+ model_dir: str | Path = "model",
15
+ ) -> Path:
16
+ """Download model from mlflow"""
17
+ import mlflow.artifacts
18
+ import mlflow.models
19
+ from mlflow.client import MlflowClient
20
+
21
+ logger.info(f"Looking for model {model_name}/{model_stage}")
22
+
23
+ if isinstance(model_dir, str):
24
+ model_dir = Path(model_dir)
25
+
26
+ client = MlflowClient()
27
+ model_versions = client.get_latest_versions(model_name, stages=[model_stage])
28
+ if len(model_versions) != 1:
29
+ raise ValueError(f"No model version for {model_name}/{model_stage}")
30
+
31
+ artifact_uri = model_versions[0].source
32
+ model_version = model_versions[0].version
33
+
34
+ logger.info(f"Found version {model_version} for {model_name}/{model_stage}")
35
+
36
+ model_path = model_dir / artifact_uri.split("/")[-1] # type: ignore
37
+ if model_path.exists():
38
+ logger.info(f"Found model in {model_path}, skipping download")
39
+ return model_path
40
+
41
+ logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
42
+ model_path = mlflow.artifacts.download_artifacts(artifact_uri, dst_path=str(model_dir))
43
+ logger.info(f"Succesfully downloaded {model_name}")
44
+
45
+ model_info = mlflow.models.get_model_info(model_path)
46
+ metadata = model_info.metadata
47
+ metadata_path = Path(model_path) / "metadata.json"
48
+ logger.info(f"Saving metadata to {metadata_path}")
49
+ with open(metadata_path, "w", encoding="utf-8") as file:
50
+ json.dump(metadata, file)
51
+
52
+ return Path(model_path)