Spaces:
Runtime error
Runtime error
"""Utils""" | |
from __future__ import annotations | |
import json | |
from pathlib import Path | |
from typing import Literal | |
from loguru import logger | |
def download_model(
    model_name: str,
    model_stage: Literal["staging", "production"],
    model_dir: str | Path = "model",
) -> Path:
    """Download a registered model from MLflow, reusing a local copy if present.

    Resolves the latest version of ``model_name`` in ``model_stage`` through the
    MLflow model registry. If a directory named after the last component of that
    version's artifact URI already exists under ``model_dir``, it is returned
    without downloading. Otherwise the artifacts are downloaded into
    ``model_dir`` and the model's MLflow metadata is written to
    ``metadata.json`` inside the downloaded directory.

    Args:
        model_name: Registered model name in the MLflow model registry.
        model_stage: Registry stage whose latest version is fetched.
        model_dir: Local directory the artifacts are downloaded into.

    Returns:
        Local path of the downloaded (or previously downloaded) model directory.

    Raises:
        ValueError: If the stage does not resolve to exactly one model version,
            or the version's artifact URI has no usable final path component.
    """
    # Imported inside the function, presumably so that importing this module
    # does not require mlflow to be installed.
    import mlflow.artifacts
    import mlflow.models
    from mlflow.client import MlflowClient

    logger.info(f"Looking for model {model_name}/{model_stage}")
    model_dir = Path(model_dir)

    client = MlflowClient()
    model_versions = client.get_latest_versions(model_name, stages=[model_stage])
    if not model_versions:
        raise ValueError(f"No model version for {model_name}/{model_stage}")
    if len(model_versions) > 1:
        raise ValueError(
            f"Expected one model version for {model_name}/{model_stage}, "
            f"found {len(model_versions)}"
        )

    artifact_uri = model_versions[0].source
    model_version = model_versions[0].version
    logger.info(f"Found version {model_version} for {model_name}/{model_stage}")

    # Strip trailing slashes first: for a URI like ".../model/" the last
    # component would otherwise be "", making the cache path equal to model_dir
    # itself and wrongly skipping the download.
    artifact_name = artifact_uri.rstrip("/").rsplit("/", 1)[-1]  # type: ignore
    if not artifact_name:
        raise ValueError(
            f"Cannot derive a model directory name from artifact URI {artifact_uri!r}"
        )
    cached_path = model_dir / artifact_name
    if cached_path.exists():
        # NOTE(review): only existence is checked, so a directory left behind by
        # an interrupted download is treated as a complete model.
        logger.info(f"Found model in {cached_path}, skipping download")
        return cached_path

    logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
    local_path = mlflow.artifacts.download_artifacts(
        artifact_uri, dst_path=str(model_dir)
    )
    logger.info(f"Successfully downloaded {model_name}")

    model_info = mlflow.models.get_model_info(local_path)
    # NOTE(review): metadata is presumably None when none was logged with the
    # model; json.dump then writes `null` — confirm consumers handle that.
    metadata_path = Path(local_path) / "metadata.json"
    logger.info(f"Saving metadata to {metadata_path}")
    with open(metadata_path, "w", encoding="utf-8") as file:
        json.dump(model_info.metadata, file)
    return Path(local_path)