import os
import warnings
from typing import Tuple

import torch


def export_onnx(
    model: torch.nn.Module,
    input_shape: Tuple[int, ...],
    export_path: str,
    opset: int,
    export_dtype: torch.dtype,
    export_device: torch.device,
) -> None:
    """Export ``model`` to ONNX at ``export_path`` with a dynamic batch dimension."""
    model.eval()

    # Dummy input used for tracing; the first ("batch") axis is exported as dynamic
    # so the ONNX model accepts arbitrary batch sizes.
    dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)}
    dynamic_axes = {
        "x": {0: "batch_size"},
    }
    output_names = ["image_embeddings"]

    # Create the target directory if export_path includes one that does not exist yet.
    export_dir = os.path.dirname(export_path)
    if export_dir and not os.path.exists(export_dir):
        os.makedirs(export_dir)

    # Tracing emits TracerWarning/UserWarning noise that is harmless for export,
    # so it is suppressed locally here.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting ONNX model to {export_path}...")
        with open(export_path, "wb") as f:
            torch.onnx.export(
                model,
                tuple(dummy_input.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
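

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the toy model, input
    # shape, output path, and opset below are illustrative assumptions only.
    example_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.Flatten(),
    )
    export_onnx(
        model=example_model,
        input_shape=(1, 3, 64, 64),
        export_path="out/example.onnx",
        opset=17,
        export_dtype=torch.float32,
        export_device=torch.device("cpu"),
    )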