File size: 1,513 Bytes
a7ab59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torchvision.models as models
from model_code import InitialOnlyImageTagger  # Assume model_code.py classes are accessible
from safetensors.torch import load_file

# Export the Initial-only image tagger to ONNX.
#
# Steps: load the trained weights from a safetensors file, rebuild the model,
# trace it with a dummy input, and write the ONNX graph to disk.

# Load the trained weights (Initial-only model). Adjust the path to your weights file.
# For a plain PyTorch checkpoint, torch.load(path, map_location="cpu") can be used instead.
safetensors_path = 'model_initial.safetensors'
state_dict = load_file(safetensors_path, device='cpu')

# Instantiate the model with the same parameters as training.
# `dataset` is passed as None; per the original author it is not needed for forward.
# NOTE(review): pretrained=True presumably fetches backbone weights that are then
# overwritten by load_state_dict below — confirm whether pretrained=False is safe here.
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode so the exported graph uses inference behaviour

# Example input used for tracing: one float32 image tensor of shape (1, 3, 512, 512).
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

# Export to ONNX.
onnx_path = "camie_tagger_initial_v15.onnx"
torch.onnx.export(
    model, dummy_input, onnx_path,
    export_params=True,        # store the trained parameter weights in the model file
    opset_version=13,          # ONNX opset version (13 is widely supported)
    do_constant_folding=True,  # optimize constant expressions
    input_names=["input"],
    # Per the original author, model.forward returns two outputs
    # (identical for the Initial-only model).
    output_names=["initial_logits", "refined_logits"],
    # Mark the batch axis as dynamic on the input AND on both outputs, so the
    # exported graph advertises a consistent, named variable batch dimension
    # instead of leaving the output shapes to exporter inference.
    dynamic_axes={
        "input": {0: "batch_size"},
        "initial_logits": {0: "batch_size"},
        "refined_logits": {0: "batch_size"},
    },
)
print(f"ONNX model saved to: {onnx_path}")