# Hugging Face Hub file-viewer residue (preserved as comments so the module parses):
# bilegentile's picture
# Upload folder using huggingface_hub
# c19ca42 verified
# raw
# history blame contribute delete
# 1.77 kB
from __future__ import annotations
from typing import Literal, Tuple, TYPE_CHECKING
import onnxoptimizer
if TYPE_CHECKING:
import onnxruntime as ort
from onnx.onnx_pb import ModelProto
OnnxInputShape = Literal["BCHW", "BHWC"]
def as_int(value) -> int | None:
if isinstance(value, int):
return value
return None
def parse_onnx_shape(
shape: Tuple[int | str, str | int, str | int, str | int]
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
if isinstance(shape[1], int) and shape[1] <= 4:
return "BCHW", shape[1], as_int(shape[3]), as_int(shape[2])
elif isinstance(shape[3], int) and shape[3] <= 4:
return "BHWC", shape[3], as_int(shape[2]), as_int(shape[1])
else:
return "BCHW", 3, as_int(shape[3]), as_int(shape[2])
def get_input_shape(
    session: ort.InferenceSession,
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
    """Parse the session's first input tensor.

    Returns ``(layout, channels, width, height)``; width and height are
    ``None`` when the corresponding dimension is dynamic.
    """
    first_input = session.get_inputs()[0]
    return parse_onnx_shape(first_input.shape)
def get_output_shape(
    session: ort.InferenceSession,
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
    """Parse the session's first output tensor.

    Returns ``(layout, channels, width, height)``; width and height are
    ``None`` when the corresponding dimension is dynamic.
    """
    first_output = session.get_outputs()[0]
    return parse_onnx_shape(first_output.shape)
def safely_optimize_onnx_model(model_proto: ModelProto) -> ModelProto:
    """Best-effort optimization of *model_proto* using onnxoptimizer.

    Applies onnxoptimizer's fuse-and-elimination passes. If optimization
    fails for any reason, the model is returned unchanged — this helper
    deliberately never raises on optimizer problems.
    """
    try:
        passes = onnxoptimizer.get_fuse_and_elimination_passes()
        model_proto = onnxoptimizer.optimize(model_proto, passes)
    except Exception:
        # Intentional best-effort swallow, but narrowed from a bare `except:`
        # so KeyboardInterrupt/SystemExit still propagate to the caller.
        pass
    return model_proto