from __future__ import annotations

from typing import Any
from weakref import WeakKeyDictionary

import onnxruntime as ort


def create_inference_session(
    model: Any,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool = False,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Build an ONNX Runtime session for the requested execution provider.

    Providers are listed in priority order, so ONNX Runtime can fall back
    to CUDA and then CPU when the preferred provider is unavailable.
    """
    if execution_provider == "TensorrtExecutionProvider":
        # Only enable the TensorRT engine cache when a cache path was given.
        providers = [
            (
                "TensorrtExecutionProvider",
                {
                    "device_id": gpu_index,
                    "trt_engine_cache_enable": tensorrt_cache_path is not None,
                    "trt_engine_cache_path": tensorrt_cache_path,
                    "trt_fp16_enable": should_tensorrt_fp16,
                },
            ),
            ("CUDAExecutionProvider", {"device_id": gpu_index}),
            "CPUExecutionProvider",
        ]
    elif execution_provider == "CUDAExecutionProvider":
        providers = [
            ("CUDAExecutionProvider", {"device_id": gpu_index}),
            "CPUExecutionProvider",
        ]
    else:
        providers = [execution_provider, "CPUExecutionProvider"]
    # `model.bytes` is expected to hold the serialized ONNX model.
    return ort.InferenceSession(model.bytes, providers=providers)


# One cached session per model object; the weak key lets an entry disappear
# automatically once the model itself is garbage-collected.
__session_cache: WeakKeyDictionary[Any, ort.InferenceSession] = WeakKeyDictionary()


def get_onnx_session(
    model: Any,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Return the cached session for ``model``, creating it on first use."""
    cached = __session_cache.get(model)
    if cached is None:
        cached = create_inference_session(
            model,
            gpu_index,
            execution_provider,
            should_tensorrt_fp16,
            tensorrt_cache_path,
        )
        __session_cache[model] = cached
    return cached
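

# A minimal usage sketch, assuming a hypothetical `DemoModel` wrapper: the
# only requirements the code above implies are that the model object is
# weak-referenceable (so it can key the WeakKeyDictionary) and exposes a
# `bytes` attribute holding the serialized ONNX model.
if __name__ == "__main__":
    class DemoModel:
        def __init__(self, path: str) -> None:
            with open(path, "rb") as f:
                self.bytes = f.read()

    model = DemoModel("model.onnx")  # hypothetical model file path
    session = get_onnx_session(
        model,
        gpu_index=0,
        execution_provider="CPUExecutionProvider",
        should_tensorrt_fp16=False,
    )
    # Repeated calls with the same model object return the cached session.
    assert get_onnx_session(model, 0, "CPUExecutionProvider", False) is session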