from __future__ import annotations

from typing import Any
from weakref import WeakKeyDictionary

import onnxruntime as ort


def create_inference_session(
    model,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool = False,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Create an ONNX Runtime session configured for the requested execution provider."""
    if execution_provider == "TensorrtExecutionProvider":
        # Prefer TensorRT, falling back to CUDA and finally to the CPU provider.
        providers = [
            (
                "TensorrtExecutionProvider",
                {
                    "device_id": gpu_index,
                    "trt_engine_cache_enable": tensorrt_cache_path is not None,
                    "trt_engine_cache_path": tensorrt_cache_path,
                    "trt_fp16_enable": should_tensorrt_fp16,
                },
            ),
            (
                "CUDAExecutionProvider",
                {
                    "device_id": gpu_index,
                },
            ),
            "CPUExecutionProvider",
        ]
    elif execution_provider == "CUDAExecutionProvider":
        providers = [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": gpu_index,
                },
            ),
            "CPUExecutionProvider",
        ]
    else:
        # Any other provider is tried first, with the CPU provider as a fallback.
        providers = [execution_provider, "CPUExecutionProvider"]

    # ``model.bytes`` is expected to hold the serialized ONNX graph.
    session = ort.InferenceSession(model.bytes, providers=providers)
    return session


# Sessions are cached per model object; because the cache is a WeakKeyDictionary,
# entries are dropped automatically once the model is garbage collected.
__session_cache: WeakKeyDictionary[Any, ort.InferenceSession] = WeakKeyDictionary()


def get_onnx_session(
    model: Any,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Return the cached session for ``model``, creating one on the first call."""
    cached = __session_cache.get(model)
    if cached is None:
        cached = create_inference_session(
            model,
            gpu_index,
            execution_provider,
            should_tensorrt_fp16,
            tensorrt_cache_path,
        )
        __session_cache[model] = cached
    return cached
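

# Example usage (a minimal sketch, not part of the module's API). It assumes a
# hypothetical model wrapper that exposes the serialized ONNX graph via a
# ``bytes`` attribute, which is what create_inference_session reads; the input
# name "input" and the input_array variable are likewise illustrative only.
#
#     import numpy as np
#
#     class LoadedModel:
#         def __init__(self, path: str) -> None:
#             with open(path, "rb") as f:
#                 self.bytes = f.read()
#
#     model = LoadedModel("model.onnx")
#     session = get_onnx_session(
#         model,
#         gpu_index=0,
#         execution_provider="CUDAExecutionProvider",
#         should_tensorrt_fp16=False,
#     )
#     input_array = np.zeros((1, 3, 224, 224), dtype=np.float32)
#     outputs = session.run(None, {"input": input_array})
#
# Note that the model wrapper must be weakly referenceable (a plain class
# instance is; a bare ``bytes`` object or dict is not), since the session cache
# keys on it via WeakKeyDictionary, and the cache is keyed only by the model,
# so later calls with different provider arguments reuse the first session.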