import types
import argparse

import torch
import torch.nn.functional as F
import numpy as np
import onnx
import onnxsim
from basicsr.archs.ddcolor_arch import DDColor
from onnx import load_model, save_model, shape_inference
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference


def parse_args():
    """Parse the command-line options of the DDColor ONNX export script.

    Returns:
        argparse.Namespace with ``input_size``, ``batch_size``,
        ``model_path``, ``model_size``, ``decoder_type``, ``export_path``
        and ``opset`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export DDColor model to ONNX.")
    parser.add_argument(
        "--input_size",
        type=int,
        default=512,
        help="Input image dimension.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Input batch size.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the PyTorch checkpoint to load (must contain a 'params' entry).",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help="Model size: 'tiny' selects the convnext-t encoder, "
             "any other value selects convnext-l.",
    )
    parser.add_argument(
        "--decoder_type",
        type=str,
        default="MultiScaleColorDecoder",
        help="Decoder type: 'MultiScaleColorDecoder' or 'SingleColorDecoder'.",
    )
    parser.add_argument(
        "--export_path",
        type=str,
        default="./model.onnx",
        help="Path to export ONNX model.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=12,
        help="ONNX opset version.",
    )
    return parser.parse_args()


def _build_model(args, device):
    """Instantiate the DDColor network selected by ``args`` on ``device``.

    Raises:
        NotImplementedError: if ``args.decoder_type`` is not one of the two
            decoders handled below.
    """
    encoder_name = 'convnext-t' if args.model_size == 'tiny' else 'convnext-l'

    # Constructor arguments shared by both decoder variants.
    common_kwargs = dict(
        encoder_name=encoder_name,
        input_size=[args.input_size, args.input_size],
        num_output_channels=2,
        last_norm='Spectral',
        do_normalize=False,
    )

    # NOTE: per the original author, inference/colorization_pipeline.py
    # hardcodes decoder_type = "MultiScaleColorDecoder".
    if args.decoder_type == 'MultiScaleColorDecoder':
        model = DDColor(
            decoder_name='MultiScaleColorDecoder',
            num_queries=100,
            num_scales=3,
            dec_layers=9,
            **common_kwargs,
        )
    elif args.decoder_type == 'SingleColorDecoder':
        model = DDColor(
            decoder_name='SingleColorDecoder',
            num_queries=256,
            **common_kwargs,
        )
    else:
        # The original code did `raise("...")`, which itself fails with
        # "TypeError: exceptions must derive from BaseException".
        raise NotImplementedError(
            f"decoder_type {args.decoder_type!r} not implemented."
        )
    return model.to(device)


def create_onnx_export(args):
    """Build DDColor, load its weights and export it to ``args.export_path``.

    Args:
        args: namespace providing ``input_size``, ``batch_size``,
            ``model_path``, ``model_size``, ``decoder_type``,
            ``export_path`` and ``opset`` (see ``parse_args``).

    Raises:
        NotImplementedError: if ``args.decoder_type`` is not supported.
    """
    input_size = args.input_size
    device = torch.device('cpu')

    model = _build_model(args, device)

    # NOTE(security): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only load checkpoints from trusted sources.
    # strict=False means missing/unexpected keys are ignored silently.
    model.load_state_dict(
        torch.load(args.model_path, map_location=device)['params'],
        strict=False)
    model.eval()

    channels = 3  # RGB image has 3 channels
    random_input = torch.rand(
        (args.batch_size, channels, input_size, input_size),
        dtype=torch.float32,
    )

    # A value of 0 marks the corresponding axes as dynamic in the ONNX graph.
    # NOTE(review): in that case the dummy input above (and the model's
    # input_size) is also zero-sized, which looks unlikely to trace
    # correctly -- confirm the intended usage before relying on it.
    dynamic_axes = {}
    if args.batch_size == 0:
        dynamic_axes[0] = "batch"
    if input_size == 0:
        dynamic_axes[2] = "height"
        dynamic_axes[3] = "width"

    torch.onnx.export(
        model,
        random_input,
        args.export_path,
        opset_version=args.opset,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": dynamic_axes,
            "output": dynamic_axes
        },
    )


def check_onnx_export(export_path):
    """Run shape inference, validation and simplification on an ONNX file.

    Every stage reads the model from ``export_path`` and writes its result
    back to the same path, so the file is modified in place.

    Args:
        export_path: path of the ONNX model to post-process.
    """
    # Stage 1: ONNX's own (strict) shape inference.
    inferred = shape_inference.infer_shapes(
        load_model(export_path),
        check_type=True,
        strict_mode=True,
        data_prop=True,
    )
    save_model(inferred, export_path)

    # Stage 2: onnxruntime's symbolic shape inference.
    symbolic = SymbolicShapeInference.infer_shapes(
        load_model(export_path),
        auto_merge=True,
        guess_output_rank=True,
    )
    save_model(symbolic, export_path)

    # Stage 3: structural check, then simplification with onnxsim.
    model_onnx = onnx.load(export_path)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    model_onnx, check = onnxsim.simplify(model_onnx)
    # `check` is the validation flag returned by onnxsim.simplify.
    assert check, "assert check failed"
    onnx.save(model_onnx, export_path)


def main():
    """Script entry point: export the model, then verify and simplify it."""
    args = parse_args()
    create_onnx_export(args)
    print(f'ONNX file successfully created at {args.export_path}')
    check_onnx_export(args.export_path)
    print(f'ONNX file at {args.export_path} verifed shapes and simplified')


if __name__ == '__main__':
    main()