Spaces:
Runtime error
Runtime error
File size: 1,766 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from __future__ import annotations
from typing import Literal, Tuple, TYPE_CHECKING
import onnxoptimizer
if TYPE_CHECKING:
import onnxruntime as ort
from onnx.onnx_pb import ModelProto
# Axis order of a 4-D image tensor, used for both model inputs and outputs:
#   "BCHW" -> (batch, channels, height, width)
#   "BHWC" -> (batch, height, width, channels)
OnnxInputShape = Literal[
    "BCHW",
    "BHWC",
]
def as_int(value) -> int | None:
    """
    Return ``value`` unchanged when it is an ``int``, otherwise ``None``.

    Used to turn ONNX dimensions into concrete sizes: symbolic (string) or
    unknown dimensions collapse to ``None``.
    """
    return value if isinstance(value, int) else None
def parse_onnx_shape(
    shape: Tuple[int | str, str | int, str | int, str | int]
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
    """
    Guess the layout and size of a 4-D ONNX tensor shape.

    Index 1 is tried first as a channel axis, then index 3. A dimension
    counts as "channels" when it is a concrete int no larger than 4. When
    neither qualifies, the tensor is assumed to be 3-channel BCHW.

    Returns ``(layout, channels, width, height)``. ``width`` and ``height``
    are ``None`` when the corresponding dimension is not a concrete int.
    """
    after_batch, last = shape[1], shape[3]

    if isinstance(after_batch, int) and after_batch <= 4:
        layout, channels = "BCHW", after_batch
    elif isinstance(last, int) and last <= 4:
        layout, channels = "BHWC", last
    else:
        # No small fixed channel axis found: fall back to RGB, channels-first.
        layout, channels = "BCHW", 3

    if layout == "BHWC":
        width_dim, height_dim = shape[2], shape[1]
    else:
        width_dim, height_dim = shape[3], shape[2]

    return layout, channels, as_int(width_dim), as_int(height_dim)
def get_input_shape(
    session: ort.InferenceSession,
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
    """
    Describe the first input of an ONNX Runtime session.

    Returns ``(layout, channels, width, height)`` as produced by
    ``parse_onnx_shape``. ``layout`` is "BCHW" or "BHWC"; ``width`` and
    ``height`` are ``None`` when that dimension is not a fixed integer.
    """
    first_input = session.get_inputs()[0]
    return parse_onnx_shape(first_input.shape)
def get_output_shape(
    session: ort.InferenceSession,
) -> Tuple[OnnxInputShape, int, int | None, int | None]:
    """
    Describe the first output of an ONNX Runtime session.

    Returns ``(layout, channels, width, height)`` as produced by
    ``parse_onnx_shape``. ``layout`` is "BCHW" or "BHWC"; ``width`` and
    ``height`` are ``None`` when that dimension is not a fixed integer.
    """
    outputs = session.get_outputs()
    return parse_onnx_shape(outputs[0].shape)
def safely_optimize_onnx_model(model_proto: ModelProto) -> ModelProto:
    """
    Optimize the model with onnxoptimizer's fuse-and-elimination passes.

    This is best-effort. If onnxoptimizer cannot be imported, or if
    optimization raises, the original ``model_proto`` is returned unchanged.
    """
    try:
        # Imported lazily so that a missing optional dependency is handled
        # here instead of failing when this module is imported.
        # NOTE(review): the module-level ``import onnxoptimizer`` at the top of
        # this file still raises first when the package is absent; consider
        # removing it so this fallback can actually take effect.
        import onnxoptimizer
    except ImportError:
        return model_proto

    try:
        passes = onnxoptimizer.get_fuse_and_elimination_passes()
        return onnxoptimizer.optimize(model_proto, passes)
    except Exception:
        # Unlike the previous bare ``except:``, this does not swallow
        # KeyboardInterrupt/SystemExit. Failures are logged rather than
        # silently ignored, and the unoptimized model is still usable.
        import logging

        logging.getLogger(__name__).warning(
            "onnxoptimizer failed; using the unoptimized model", exc_info=True
        )
        return model_proto
|