LPDoctor commited on
Commit
1514c39
·
verified ·
1 Parent(s): ba47ec9

update the export-onnx.py

Browse files
Files changed (1) hide show
  1. export-onnx.py +199 -0
export-onnx.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
3
+
4
+ """
5
+ We use
6
+ https://hf-mirror.com/yuekai/model_repo_sense_voice_small/blob/main/export_onnx.py
7
+ as a reference while writing this file.
8
+
9
+ Thanks to https://github.com/yuekaizhang for making the file public.
10
+ """
11
+
12
+ import os
13
+ from typing import Any, Dict, Tuple
14
+
15
+ import onnx
16
+ import torch
17
+ from model import SenseVoiceSmall
18
+ from onnxruntime.quantization import QuantType, quantize_dynamic
19
+
20
+
21
+ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
22
+ """Add meta data to an ONNX model. It is changed in-place.
23
+
24
+ Args:
25
+ filename:
26
+ Filename of the ONNX model to be changed.
27
+ meta_data:
28
+ Key-value pairs.
29
+ """
30
+ model = onnx.load(filename)
31
+ while len(model.metadata_props):
32
+ model.metadata_props.pop()
33
+
34
+ for key, value in meta_data.items():
35
+ meta = model.metadata_props.add()
36
+ meta.key = key
37
+ meta.value = str(value)
38
+
39
+ onnx.save(model, filename)
40
+
41
+
42
+ def modified_forward(
43
+ self,
44
+ x: torch.Tensor,
45
+ x_length: torch.Tensor,
46
+ language: torch.Tensor,
47
+ text_norm: torch.Tensor,
48
+ ):
49
+ """
50
+ Args:
51
+ x:
52
+ A 3-D tensor of shape (N, T, C) with dtype torch.float32
53
+ x_length:
54
+ A 1-D tensor of shape (N,) with dtype torch.int32
55
+ language:
56
+ A 1-D tensor of shape (N,) with dtype torch.int32
57
+ See also https://github.com/FunAudioLLM/SenseVoice/blob/a80e676461b24419cf1130a33d4dd2f04053e5cc/model.py#L640
58
+ text_norm:
59
+ A 1-D tensor of shape (N,) with dtype torch.int32
60
+ See also https://github.com/FunAudioLLM/SenseVoice/blob/a80e676461b24419cf1130a33d4dd2f04053e5cc/model.py#L642
61
+ """
62
+ language_query = self.embed(language).unsqueeze(1)
63
+ text_norm_query = self.embed(text_norm).unsqueeze(1)
64
+
65
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]])).repeat(x.size(0), 1, 1)
66
+
67
+ x = torch.cat((language_query, event_emo_query, text_norm_query, x), dim=1)
68
+ x_length += 4
69
+
70
+ encoder_out, encoder_out_lens = self.encoder(x, x_length)
71
+ if isinstance(encoder_out, tuple):
72
+ encoder_out = encoder_out[0]
73
+
74
+ ctc_logits = self.ctc.ctc_lo(encoder_out)
75
+
76
+ return ctc_logits
77
+
78
+
79
+ def load_cmvn(filename) -> Tuple[str, str]:
80
+ neg_mean = None
81
+ inv_stddev = None
82
+
83
+ with open(filename) as f:
84
+ for line in f:
85
+ if not line.startswith("<LearnRateCoef>"):
86
+ continue
87
+ t = line.split()[3:-1]
88
+
89
+ if neg_mean is None:
90
+ neg_mean = ",".join(t)
91
+ else:
92
+ inv_stddev = ",".join(t)
93
+
94
+ return neg_mean, inv_stddev
95
+
96
+
97
+ def generate_tokens(params):
98
+ sp = params["tokenizer"].sp
99
+ with open("tokens.txt", "w", encoding="utf-8") as f:
100
+ for i in range(sp.vocab_size()):
101
+ f.write(f"{sp.id_to_piece(i)} {i}\n")
102
+
103
+ os.system("head tokens.txt; tail -n200 tokens.txt")
104
+
105
+
106
+ def display_params(params):
107
+ print("----------params----------")
108
+ print(params)
109
+
110
+ print("----------frontend_conf----------")
111
+ print(params["frontend_conf"])
112
+
113
+ os.system(f"cat {params['frontend_conf']['cmvn_file']}")
114
+
115
+ print("----------config----------")
116
+ print(params["config"])
117
+
118
+ os.system(f"cat {params['config']}")
119
+
120
+
121
+ def main():
122
+ model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall")
123
+ display_params(params)
124
+
125
+ generate_tokens(params)
126
+
127
+ model.__class__.forward = modified_forward
128
+
129
+ x = torch.randn(2, 100, 560, dtype=torch.float32)
130
+ x_length = torch.tensor([80, 100], dtype=torch.int32)
131
+ language = torch.tensor([0, 3], dtype=torch.int32)
132
+ text_norm = torch.tensor([14, 15], dtype=torch.int32)
133
+
134
+ opset_version = 13
135
+ filename = "model.onnx"
136
+ torch.onnx.export(
137
+ model,
138
+ (x, x_length, language, text_norm),
139
+ filename,
140
+ opset_version=opset_version,
141
+ input_names=["x", "x_length", "language", "text_norm"],
142
+ output_names=["logits"],
143
+ dynamic_axes={
144
+ "x": {0: "N", 1: "T"},
145
+ "x_length": {0: "N"},
146
+ "language": {0: "N"},
147
+ "text_norm": {0: "N"},
148
+ "logits": {0: "N", 1: "T"},
149
+ },
150
+ )
151
+
152
+ lfr_window_size = params["frontend_conf"]["lfr_m"]
153
+ lfr_window_shift = params["frontend_conf"]["lfr_n"]
154
+
155
+ neg_mean, inv_stddev = load_cmvn(params["frontend_conf"]["cmvn_file"])
156
+ vocab_size = params["tokenizer"].sp.vocab_size()
157
+
158
+ meta_data = {
159
+ "lfr_window_size": lfr_window_size,
160
+ "lfr_window_shift": lfr_window_shift,
161
+ "normalize_samples": 0, # input should be in the range [-32768, 32767]
162
+ "neg_mean": neg_mean,
163
+ "inv_stddev": inv_stddev,
164
+ "model_type": "sense_voice_ctc",
165
+ # version 1: Use QInt8
166
+ # version 2: Use QUInt8
167
+ "version": "2",
168
+ "model_author": "iic",
169
+ "maintainer": "k2-fsa",
170
+ "vocab_size": vocab_size,
171
+ "comment": "iic/SenseVoiceSmall",
172
+ "lang_auto": model.lid_dict["auto"],
173
+ "lang_zh": model.lid_dict["zh"],
174
+ "lang_en": model.lid_dict["en"],
175
+ "lang_yue": model.lid_dict["yue"], # cantonese
176
+ "lang_ja": model.lid_dict["ja"],
177
+ "lang_ko": model.lid_dict["ko"],
178
+ "lang_nospeech": model.lid_dict["nospeech"],
179
+ "with_itn": model.textnorm_dict["withitn"],
180
+ "without_itn": model.textnorm_dict["woitn"],
181
+ "url": "https://huggingface.co/FunAudioLLM/SenseVoiceSmall",
182
+ }
183
+ add_meta_data(filename=filename, meta_data=meta_data)
184
+
185
+ filename_int8 = "model.int8.onnx"
186
+ quantize_dynamic(
187
+ model_input=filename,
188
+ model_output=filename_int8,
189
+ op_types_to_quantize=["MatMul"],
190
+ # Note that we have to use QUInt8 here.
191
+ #
192
+ # When QInt8 is used, C++ onnxruntime produces incorrect results
193
+ weight_type=QuantType.QUInt8,
194
+ )
195
+
196
+
197
+ if __name__ == "__main__":
198
+ torch.manual_seed(20240717)
199
+ main()