from __future__ import annotations

from typing import Any
from weakref import WeakKeyDictionary

import onnxruntime as ort


def create_inference_session(
    model: Any,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool = False,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Build an ONNX Runtime session for `model` on the requested execution provider."""
    if execution_provider == "TensorrtExecutionProvider":
        # Prefer TensorRT, falling back to CUDA and finally the CPU provider.
        providers = [
            (
                "TensorrtExecutionProvider",
                {
                    "device_id": gpu_index,
                    "trt_engine_cache_enable": tensorrt_cache_path is not None,
                    "trt_engine_cache_path": tensorrt_cache_path,
                    "trt_fp16_enable": should_tensorrt_fp16,
                },
            ),
            (
                "CUDAExecutionProvider",
                {
                    "device_id": gpu_index,
                },
            ),
            "CPUExecutionProvider",
        ]
    elif execution_provider == "CUDAExecutionProvider":
        # CUDA with a CPU fallback.
        providers = [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": gpu_index,
                },
            ),
            "CPUExecutionProvider",
        ]
    else:
        # Any other provider, with a CPU fallback.
        providers = [execution_provider, "CPUExecutionProvider"]

    session = ort.InferenceSession(model.bytes, providers=providers)
    return session
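

# Illustrative call (hypothetical values): a TensorRT-backed session with FP16
# enabled and engine caching under /tmp/trt_cache. The only assumption made
# about `model` is that it exposes a `.bytes` attribute holding the serialized
# ONNX graph, as used above.
#
#     session = create_inference_session(
#         model,
#         gpu_index=0,
#         execution_provider="TensorrtExecutionProvider",
#         should_tensorrt_fp16=True,
#         tensorrt_cache_path="/tmp/trt_cache",
#     )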


# Sessions are cached per model object. Keys are held weakly, so a cache entry
# is dropped automatically once its model is garbage collected.
__session_cache: WeakKeyDictionary[Any, ort.InferenceSession] = WeakKeyDictionary()


def get_onnx_session(
    model: Any,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Return a cached session for `model`, creating one on first use.

    Note that the cache is keyed by the model object alone: the remaining
    arguments only take effect the first time a given model is seen.
    """
    cached = __session_cache.get(model)
    if cached is None:
        cached = create_inference_session(
            model,
            gpu_index,
            execution_provider,
            should_tensorrt_fp16,
            tensorrt_cache_path,
        )
        __session_cache[model] = cached
    return cached
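

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how these helpers might be called. `OnnxModel` is a
# hypothetical stand-in for whatever object the caller passes in: the only
# contract this module relies on is a `.bytes` attribute holding the
# serialized ONNX graph, plus weak-referenceability and identity hashing so
# the object can key the WeakKeyDictionary cache.
if __name__ == "__main__":
    from dataclasses import dataclass, field

    @dataclass(eq=False)  # eq=False keeps identity hashing, required for weak-key use
    class OnnxModel:
        bytes: bytes = field(repr=False)

    with open("model.onnx", "rb") as f:  # hypothetical model path
        model = OnnxModel(bytes=f.read())

    # The first call creates the session; the second returns the cached one.
    session = get_onnx_session(model, 0, "CPUExecutionProvider", False)
    assert get_onnx_session(model, 0, "CPUExecutionProvider", False) is session
    print([i.name for i in session.get_inputs()])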