gauravprakashh commited on
Commit
7d31858
·
verified ·
1 Parent(s): 8faf3a4

Create onnx_model.py

Browse files
Files changed (1) hide show
  1. onnx_model.py +84 -0
onnx_model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import onnxruntime as ort
10
+ from loguru import logger
11
+
12
+
13
+ @dataclass
14
+ class ModelInfo:
15
+ base_model: str
16
+
17
+ @classmethod
18
+ def from_dir(cls, model_dir: Path):
19
+ with open(model_dir / "metadata.json", "r", encoding="utf-8") as file:
20
+ data = json.load(file)
21
+ return ModelInfo(base_model=data["bert_type"])
22
+
23
+
24
+ class ONNXModel:
25
+ def __init__(self, model: ort.InferenceSession, model_info: ModelInfo) -> None:
26
+ self.model = model
27
+ self.model_info = model_info
28
+ self.model_path = Path(model._model_path) # type: ignore
29
+ self.model_name = self.model_path.name
30
+
31
+ self.providers = model.get_providers()
32
+
33
+ if self.providers[0] in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
34
+ self.device = "cuda"
35
+ else:
36
+ self.device = "cpu"
37
+
38
+ self.io_types = {
39
+ "input_ids": np.int32,
40
+ "attention_mask": np.bool_
41
+ }
42
+
43
+ self.input_names = [el.name for el in model.get_inputs()]
44
+ self.output_name = model.get_outputs()[0].name
45
+
46
+ @staticmethod
47
+ def load_session(
48
+ path: str | Path,
49
+ provider: str = "CPUExecutionProvider",
50
+ session_options: ort.SessionOptions | None = None,
51
+ provider_options: dict[str, Any] | None = None,
52
+ ) -> ort.InferenceSession:
53
+ providers = [provider]
54
+ if provider == "TensorrtExecutionProvider":
55
+ providers.append("CUDAExecutionProvider")
56
+ elif provider == "CUDAExecutionProvider":
57
+ providers.append("CPUExecutionProvider")
58
+
59
+ if not isinstance(path, str):
60
+ path = Path(path) / "model.onnx"
61
+
62
+ providers_options = None
63
+ if provider_options is not None:
64
+ providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
65
+
66
+ session = ort.InferenceSession(
67
+ str(path),
68
+ providers=providers,
69
+ sess_options=session_options,
70
+ provider_options=providers_options,
71
+ )
72
+ logger.info("Session loaded")
73
+ return session
74
+
75
+ @classmethod
76
+ def from_dir(cls, model_dir: str | Path) -> ONNXModel:
77
+ return ONNXModel(ONNXModel.load_session(model_dir), ModelInfo.from_dir(model_dir))
78
+
79
+ def __call__(self, **model_inputs: np.ndarray):
80
+ model_inputs = {
81
+ input_name: tensor.astype(self.io_types[input_name]) for input_name, tensor in model_inputs.items()
82
+ }
83
+
84
+ return self.model.run([self.output_name], model_inputs)[0]