from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
from loguru import logger

@dataclass
class ModelInfo:
    base_model: str

    @classmethod
    def from_dir(cls, model_dir: Path) -> ModelInfo:
        with open(model_dir / "metadata.json", "r", encoding="utf-8") as file:
            data = json.load(file)
        return cls(base_model=data["bert_type"])
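
# Layout assumed by ModelInfo.from_dir: a metadata.json in the model
# directory. The "bert_type" key comes from the code above; the value
# shown here is purely illustrative:
#
#     {"bert_type": "bert-base-uncased"}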

class ONNXModel:
    def __init__(self, model: ort.InferenceSession, model_info: ModelInfo) -> None:
        self.model = model
        self.model_info = model_info
        self.model_path = Path(model._model_path)  # type: ignore[attr-defined]
        self.model_name = self.model_path.name
        self.providers = model.get_providers()
        # Providers are listed in priority order, so a GPU provider in first
        # place means inference actually runs on the GPU.
        if self.providers[0] in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
            self.device = "cuda"
        else:
            self.device = "cpu"
        # Dtypes the exported graph expects for each input tensor.
        self.io_types = {
            "input_ids": np.int32,
            "attention_mask": np.bool_,
        }
        self.input_names = [el.name for el in model.get_inputs()]
        self.output_name = model.get_outputs()[0].name

    @staticmethod
    def load_session(
        path: str | Path,
        provider: str = "CPUExecutionProvider",
        session_options: ort.SessionOptions | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> ort.InferenceSession:
        # Register fallbacks after the requested provider so the session can
        # still be created if it is unavailable: TensorRT -> CUDA, CUDA -> CPU.
        providers = [provider]
        if provider == "TensorrtExecutionProvider":
            providers.append("CUDAExecutionProvider")
        elif provider == "CUDAExecutionProvider":
            providers.append("CPUExecutionProvider")
        # A str is treated as the full path to the .onnx file; a Path is
        # treated as a model directory containing model.onnx.
        if not isinstance(path, str):
            path = Path(path) / "model.onnx"
        providers_options = None
        if provider_options is not None:
            # onnxruntime expects one options dict per provider, in order.
            providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
        session = ort.InferenceSession(
            str(path),
            providers=providers,
            sess_options=session_options,
            provider_options=providers_options,
        )
        logger.info("Session loaded")
        return session
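
    # Usage sketch for load_session (assumes a CUDA-enabled onnxruntime
    # build; the directory name is illustrative):
    #
    #     session = ONNXModel.load_session(
    #         Path("models/example"), provider="CUDAExecutionProvider"
    #     )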

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> ONNXModel:
        # Normalize to Path so both load_session and ModelInfo.from_dir
        # receive a directory path rather than a raw string.
        model_dir = Path(model_dir)
        return cls(cls.load_session(model_dir), ModelInfo.from_dir(model_dir))

    def __call__(self, **model_inputs: np.ndarray) -> np.ndarray:
        # Cast each input to the dtype the graph expects, then run the
        # session and return the single requested output.
        model_inputs = {
            input_name: tensor.astype(self.io_types[input_name])
            for input_name, tensor in model_inputs.items()
        }
        return self.model.run([self.output_name], model_inputs)[0]
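
# Minimal end-to-end sketch, assuming a model directory that contains
# model.onnx and metadata.json. The directory name, batch size, and
# sequence length below are illustrative, not part of the module.
if __name__ == "__main__":
    model = ONNXModel.from_dir("models/example")
    outputs = model(
        input_ids=np.zeros((1, 128), dtype=np.int64),
        attention_mask=np.ones((1, 128), dtype=np.int64),
    )
    logger.info("Output shape: {}", outputs.shape)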