mrfakename committed on
Commit 82bc972 · verified · 1 Parent(s): 9f98fd9

Upload 51 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. MODEL-LICENSE +395 -0
  4. config.py +254 -0
  5. copy_codebase.py +56 -0
  6. data/__init__.py +0 -0
  7. data/combined_dataset.py +466 -0
  8. data/emilia_preprocessing/delete_tar_files.sh +42 -0
  9. data/emilia_preprocessing/encodec.py +1554 -0
  10. data/emilia_preprocessing/sha256hash.py +14 -0
  11. data/emilia_preprocessing/step1_download.py +9 -0
  12. data/emilia_preprocessing/step2_log_tar_files.sh +27 -0
  13. data/emilia_preprocessing/step3_untar.sh +101 -0
  14. data/emilia_preprocessing/step4_construct_manifest.py +251 -0
  15. data/emilia_preprocessing/step5_phonemize.py +158 -0
  16. data/emilia_preprocessing/step6_encodec_encode.py +177 -0
  17. data/emilia_preprocessing/step6_encodec_encode_script.sh +19 -0
  18. data/encodec.py +1554 -0
  19. data/ll60k_preprocessing/config.yaml +75 -0
  20. data/ll60k_preprocessing/encodec.py +1554 -0
  21. data/ll60k_preprocessing/step1_download.sh +42 -0
  22. data/ll60k_preprocessing/step2_resplit_long.py +51 -0
  23. data/ll60k_preprocessing/step3_seg_phn_manifest.py +194 -0
  24. data/ll60k_preprocessing/step4_encodec_encode.py +184 -0
  25. data/ll60k_preprocessing/step4_encodec_encode_script.sh +19 -0
  26. data/ll60k_preprocessing/step5_find_nearest_neighbor.py +157 -0
  27. data/ll60k_preprocessing/step6_forced_alignment.py +86 -0
  28. data/ll60k_preprocessing/step6_forced_alignment.sh +13 -0
  29. data/ll60k_preprocessing/step7_ipa_alignment.py +114 -0
  30. data/ll60k_preprocessing/tokenizer.py +460 -0
  31. data/tokenizer.py +295 -0
  32. demo/5895_34622_000026_000002.wav +3 -0
  33. generated_tts/generated.wav +3 -0
  34. inference_commandline.py +192 -0
  35. inference_gradio.py +334 -0
  36. inference_tts_utils.py +155 -0
  37. main.py +82 -0
  38. models/modules/__init__.py +0 -0
  39. models/modules/activation.py +781 -0
  40. models/modules/embedding.py +158 -0
  41. models/modules/sampling.py +63 -0
  42. models/modules/scaling.py +1406 -0
  43. models/modules/transformer.py +1089 -0
  44. models/modules/utils.py +36 -0
  45. models/modules/visualizer.py +107 -0
  46. models/voice_star.py +784 -0
  47. pretrained/.gitkeep +0 -0
  48. steps/__init__.py +0 -0
  49. steps/optim.py +1123 -0
  50. steps/trainer.py +717 -0
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  examples/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png filter=lfs diff=lfs merge=lfs -text
37
  examples/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png filter=lfs diff=lfs merge=lfs -text
38
  illustration.png filter=lfs diff=lfs merge=lfs -text
39
+ demo/5895_34622_000026_000002.wav filter=lfs diff=lfs merge=lfs -text
40
+ generated_tts/generated.wav filter=lfs diff=lfs merge=lfs -text
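The two new audio files (the demo prompt and a generated TTS sample) are tracked through Git LFS, so the regular history only stores pointer files. As a quick sanity check, the sketch below loads the demo prompt with torchaudio (already a dependency of data/combined_dataset.py); it assumes `git lfs pull` has fetched the real audio bytes, and the printed shape/sample rate are purely informative since the commit does not state them.

import torchaudio

# Assumes the LFS object has been pulled so the path contains real audio, not a pointer file.
wav, sr = torchaudio.load("demo/5895_34622_000026_000002.wav")
print(wav.shape, sr)  # tensor of shape (channels, num_samples) and the sample rate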
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Puyuan Peng
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MODEL-LICENSE ADDED
@@ -0,0 +1,395 @@
1
+ Attribution 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution 4.0 International Public License
58
+
59
+ By exercising the Licensed Rights (defined below), You accept and agree
60
+ to be bound by the terms and conditions of this Creative Commons
61
+ Attribution 4.0 International Public License ("Public License"). To the
62
+ extent this Public License may be interpreted as a contract, You are
63
+ granted the Licensed Rights in consideration of Your acceptance of
64
+ these terms and conditions, and the Licensor grants You such rights in
65
+ consideration of benefits the Licensor receives from making the
66
+ Licensed Material available under these terms and conditions.
67
+
68
+
69
+ Section 1 -- Definitions.
70
+
71
+ a. Adapted Material means material subject to Copyright and Similar
72
+ Rights that is derived from or based upon the Licensed Material
73
+ and in which the Licensed Material is translated, altered,
74
+ arranged, transformed, or otherwise modified in a manner requiring
75
+ permission under the Copyright and Similar Rights held by the
76
+ Licensor. For purposes of this Public License, where the Licensed
77
+ Material is a musical work, performance, or sound recording,
78
+ Adapted Material is always produced where the Licensed Material is
79
+ synched in timed relation with a moving image.
80
+
81
+ b. Adapter's License means the license You apply to Your Copyright
82
+ and Similar Rights in Your contributions to Adapted Material in
83
+ accordance with the terms and conditions of this Public License.
84
+
85
+ c. Copyright and Similar Rights means copyright and/or similar rights
86
+ closely related to copyright including, without limitation,
87
+ performance, broadcast, sound recording, and Sui Generis Database
88
+ Rights, without regard to how the rights are labeled or
89
+ categorized. For purposes of this Public License, the rights
90
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
91
+ Rights.
92
+
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. Share means to provide material to the public by any means or
116
+ process that requires permission under the Licensed Rights, such
117
+ as reproduction, public display, public performance, distribution,
118
+ dissemination, communication, or importation, and to make material
119
+ available to the public including in ways that members of the
120
+ public may access the material from a place and at a time
121
+ individually chosen by them.
122
+
123
+ j. Sui Generis Database Rights means rights other than copyright
124
+ resulting from Directive 96/9/EC of the European Parliament and of
125
+ the Council of 11 March 1996 on the legal protection of databases,
126
+ as amended and/or succeeded, as well as other essentially
127
+ equivalent rights anywhere in the world.
128
+
129
+ k. You means the individual or entity exercising the Licensed Rights
130
+ under this Public License. Your has a corresponding meaning.
131
+
132
+
133
+ Section 2 -- Scope.
134
+
135
+ a. License grant.
136
+
137
+ 1. Subject to the terms and conditions of this Public License,
138
+ the Licensor hereby grants You a worldwide, royalty-free,
139
+ non-sublicensable, non-exclusive, irrevocable license to
140
+ exercise the Licensed Rights in the Licensed Material to:
141
+
142
+ a. reproduce and Share the Licensed Material, in whole or
143
+ in part; and
144
+
145
+ b. produce, reproduce, and Share Adapted Material.
146
+
147
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
148
+ Exceptions and Limitations apply to Your use, this Public
149
+ License does not apply, and You do not need to comply with
150
+ its terms and conditions.
151
+
152
+ 3. Term. The term of this Public License is specified in Section
153
+ 6(a).
154
+
155
+ 4. Media and formats; technical modifications allowed. The
156
+ Licensor authorizes You to exercise the Licensed Rights in
157
+ all media and formats whether now known or hereafter created,
158
+ and to make technical modifications necessary to do so. The
159
+ Licensor waives and/or agrees not to assert any right or
160
+ authority to forbid You from making technical modifications
161
+ necessary to exercise the Licensed Rights, including
162
+ technical modifications necessary to circumvent Effective
163
+ Technological Measures. For purposes of this Public License,
164
+ simply making modifications authorized by this Section 2(a)
165
+ (4) never produces Adapted Material.
166
+
167
+ 5. Downstream recipients.
168
+
169
+ a. Offer from the Licensor -- Licensed Material. Every
170
+ recipient of the Licensed Material automatically
171
+ receives an offer from the Licensor to exercise the
172
+ Licensed Rights under the terms and conditions of this
173
+ Public License.
174
+
175
+ b. No downstream restrictions. You may not offer or impose
176
+ any additional or different terms or conditions on, or
177
+ apply any Effective Technological Measures to, the
178
+ Licensed Material if doing so restricts exercise of the
179
+ Licensed Rights by any recipient of the Licensed
180
+ Material.
181
+
182
+ 6. No endorsement. Nothing in this Public License constitutes or
183
+ may be construed as permission to assert or imply that You
184
+ are, or that Your use of the Licensed Material is, connected
185
+ with, or sponsored, endorsed, or granted official status by,
186
+ the Licensor or others designated to receive attribution as
187
+ provided in Section 3(a)(1)(A)(i).
188
+
189
+ b. Other rights.
190
+
191
+ 1. Moral rights, such as the right of integrity, are not
192
+ licensed under this Public License, nor are publicity,
193
+ privacy, and/or other similar personality rights; however, to
194
+ the extent possible, the Licensor waives and/or agrees not to
195
+ assert any such rights held by the Licensor to the limited
196
+ extent necessary to allow You to exercise the Licensed
197
+ Rights, but not otherwise.
198
+
199
+ 2. Patent and trademark rights are not licensed under this
200
+ Public License.
201
+
202
+ 3. To the extent possible, the Licensor waives any right to
203
+ collect royalties from You for the exercise of the Licensed
204
+ Rights, whether directly or through a collecting society
205
+ under any voluntary or waivable statutory or compulsory
206
+ licensing scheme. In all other cases the Licensor expressly
207
+ reserves any right to collect such royalties.
208
+
209
+
210
+ Section 3 -- License Conditions.
211
+
212
+ Your exercise of the Licensed Rights is expressly made subject to the
213
+ following conditions.
214
+
215
+ a. Attribution.
216
+
217
+ 1. If You Share the Licensed Material (including in modified
218
+ form), You must:
219
+
220
+ a. retain the following if it is supplied by the Licensor
221
+ with the Licensed Material:
222
+
223
+ i. identification of the creator(s) of the Licensed
224
+ Material and any others designated to receive
225
+ attribution, in any reasonable manner requested by
226
+ the Licensor (including by pseudonym if
227
+ designated);
228
+
229
+ ii. a copyright notice;
230
+
231
+ iii. a notice that refers to this Public License;
232
+
233
+ iv. a notice that refers to the disclaimer of
234
+ warranties;
235
+
236
+ v. a URI or hyperlink to the Licensed Material to the
237
+ extent reasonably practicable;
238
+
239
+ b. indicate if You modified the Licensed Material and
240
+ retain an indication of any previous modifications; and
241
+
242
+ c. indicate the Licensed Material is licensed under this
243
+ Public License, and include the text of, or the URI or
244
+ hyperlink to, this Public License.
245
+
246
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
247
+ reasonable manner based on the medium, means, and context in
248
+ which You Share the Licensed Material. For example, it may be
249
+ reasonable to satisfy the conditions by providing a URI or
250
+ hyperlink to a resource that includes the required
251
+ information.
252
+
253
+ 3. If requested by the Licensor, You must remove any of the
254
+ information required by Section 3(a)(1)(A) to the extent
255
+ reasonably practicable.
256
+
257
+ 4. If You Share Adapted Material You produce, the Adapter's
258
+ License You apply must not prevent recipients of the Adapted
259
+ Material from complying with this Public License.
260
+
261
+
262
+ Section 4 -- Sui Generis Database Rights.
263
+
264
+ Where the Licensed Rights include Sui Generis Database Rights that
265
+ apply to Your use of the Licensed Material:
266
+
267
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
268
+ to extract, reuse, reproduce, and Share all or a substantial
269
+ portion of the contents of the database;
270
+
271
+ b. if You include all or a substantial portion of the database
272
+ contents in a database in which You have Sui Generis Database
273
+ Rights, then the database in which You have Sui Generis Database
274
+ Rights (but not its individual contents) is Adapted Material; and
275
+
276
+ c. You must comply with the conditions in Section 3(a) if You Share
277
+ all or a substantial portion of the contents of the database.
278
+
279
+ For the avoidance of doubt, this Section 4 supplements and does not
280
+ replace Your obligations under this Public License where the Licensed
281
+ Rights include other Copyright and Similar Rights.
282
+
283
+
284
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
285
+
286
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
287
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
288
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
289
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
290
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
291
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
292
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
293
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
294
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
295
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
296
+
297
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
298
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
299
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
300
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
301
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
302
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
303
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
304
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
305
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
306
+
307
+ c. The disclaimer of warranties and limitation of liability provided
308
+ above shall be interpreted in a manner that, to the extent
309
+ possible, most closely approximates an absolute disclaimer and
310
+ waiver of all liability.
311
+
312
+
313
+ Section 6 -- Term and Termination.
314
+
315
+ a. This Public License applies for the term of the Copyright and
316
+ Similar Rights licensed here. However, if You fail to comply with
317
+ this Public License, then Your rights under this Public License
318
+ terminate automatically.
319
+
320
+ b. Where Your right to use the Licensed Material has terminated under
321
+ Section 6(a), it reinstates:
322
+
323
+ 1. automatically as of the date the violation is cured, provided
324
+ it is cured within 30 days of Your discovery of the
325
+ violation; or
326
+
327
+ 2. upon express reinstatement by the Licensor.
328
+
329
+ For the avoidance of doubt, this Section 6(b) does not affect any
330
+ right the Licensor may have to seek remedies for Your violations
331
+ of this Public License.
332
+
333
+ c. For the avoidance of doubt, the Licensor may also offer the
334
+ Licensed Material under separate terms or conditions or stop
335
+ distributing the Licensed Material at any time; however, doing so
336
+ will not terminate this Public License.
337
+
338
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
339
+ License.
340
+
341
+
342
+ Section 7 -- Other Terms and Conditions.
343
+
344
+ a. The Licensor shall not be bound by any additional or different
345
+ terms or conditions communicated by You unless expressly agreed.
346
+
347
+ b. Any arrangements, understandings, or agreements regarding the
348
+ Licensed Material not stated herein are separate from and
349
+ independent of the terms and conditions of this Public License.
350
+
351
+
352
+ Section 8 -- Interpretation.
353
+
354
+ a. For the avoidance of doubt, this Public License does not, and
355
+ shall not be interpreted to, reduce, limit, restrict, or impose
356
+ conditions on any use of the Licensed Material that could lawfully
357
+ be made without permission under this Public License.
358
+
359
+ b. To the extent possible, if any provision of this Public License is
360
+ deemed unenforceable, it shall be automatically reformed to the
361
+ minimum extent necessary to make it enforceable. If the provision
362
+ cannot be reformed, it shall be severed from this Public License
363
+ without affecting the enforceability of the remaining terms and
364
+ conditions.
365
+
366
+ c. No term or condition of this Public License will be waived and no
367
+ failure to comply consented to unless expressly agreed to by the
368
+ Licensor.
369
+
370
+ d. Nothing in this Public License constitutes or may be interpreted
371
+ as a limitation upon, or waiver of, any privileges and immunities
372
+ that apply to the Licensor or You, including from the legal
373
+ processes of any jurisdiction or authority.
374
+
375
+
376
+ =======================================================================
377
+
378
+ Creative Commons is not a party to its public
379
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
380
+ its public licenses to material it publishes and in those instances
381
+ will be considered the “Licensor.” The text of the Creative Commons
382
+ public licenses is dedicated to the public domain under the CC0 Public
383
+ Domain Dedication. Except for the limited purpose of indicating that
384
+ material is shared under a Creative Commons public license or as
385
+ otherwise permitted by the Creative Commons policies published at
386
+ creativecommons.org/policies, Creative Commons does not authorize the
387
+ use of the trademark "Creative Commons" or any other trademark or logo
388
+ of Creative Commons without its prior written consent including,
389
+ without limitation, in connection with any unauthorized modifications
390
+ to any of its public licenses or any other arrangements,
391
+ understandings, or agreements concerning use of licensed material. For
392
+ the avoidance of doubt, this paragraph does not form part of the
393
+ public licenses.
394
+
395
+ Creative Commons may be contacted at creativecommons.org.
config.py ADDED
@@ -0,0 +1,254 @@
1
+ import argparse
2
+
3
+ def int_or_str(value):
4
+ """Custom function to allow both int and str types."""
5
+ try:
6
+ return int(value) # Try converting to integer
7
+ except ValueError:
8
+ return value # If conversion fails, return as string
9
+
10
+ def MyParser():
11
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
+ # general training
13
+ parser.add_argument("--seed", type=int, default=1)
14
+ parser.add_argument("--debug", type=int, default=0)
15
+ parser.add_argument("--multinodes", type=int, default=0)
16
+ parser.add_argument("--dist_url", default="env://", type=str)
17
+ parser.add_argument("--dist_backend", default="nccl", type=str)
18
+ parser.add_argument("--precision", type=str, default="float16", help="we might need float32 for NAR model")
19
+ parser.add_argument("--num_workers", type=int, default=8, help="per gpu")
20
+ parser.add_argument("--resume", action="store_true", default=False)
21
+ parser.add_argument("--tb_write_every_n_steps", type=int, default=100)
22
+ parser.add_argument("--print_every_n_steps", type=int, default=250)
23
+ parser.add_argument("--val_every_n_steps", type=int, default=500)
24
+ parser.add_argument("--inference_every_n_steps", type=int, default=3000, help="will only get to inference when the model is saved, and therefore this needs to be a multiple of val_every_n_steps")
25
+ parser.add_argument("--save_every_n_steps", type=int, default=10000000, help="save the model every n steps, will save the model as bundle_step$step.pth")
26
+ parser.add_argument("--lr", type=float, default=1e-4)
27
+ parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size per gpu, no matter whether using gradient_accumulation_steps")
28
+ parser.add_argument("--weight_decay", type=float, default=1e-2)
29
+ parser.add_argument("--warmup_fraction", type=float, default=0.1, help="use linear warmup, the proportion of the training steps that are used for warming up")
30
+ parser.add_argument("--num_epochs", type=int, default=10)
31
+ parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore num_epochs and use num_steps as the total amount of training, can try e.g. 400000 i.e. 400k steps")
32
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
33
+ parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_()")
34
+ parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement")
35
+ parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps")
36
+
37
+
38
+ # path
39
+ parser.add_argument("--exp_dir", type=str, default='/saltpool0/scratch/pyp/VoiceEditor/', help="will be combined with dataset name")
40
+ parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'librilight', 'spotify', they are folder name in the data dir also")
41
+ parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file")
42
+ parser.add_argument("--compact_folder_name", type=str, default=None, help="if not None, will use compact_combined_dataset.py, and this is the folder name of the compact dataset")
43
+ parser.add_argument("--inference_dataset_dir", type=str, default="/data/scratch/pyp/datasets/librilight/preprocessed", help="need to be compatible with corresponding dataset py file")
44
+
45
+ parser.add_argument("--training_stage", type=int, default=1, help="if 1, train VoiceEditor_one, if 2 train VoiceEditor_seven")
46
+ parser.add_argument("--local_wandb", type=int, default=0, help="if 1, will use local wandb, otherwise use the global one")
47
+ parser.add_argument("--wandb_entity", type=str, default="puyuanpeng", help="the entity (usually your username) for wandb")
48
+ # data
49
+ parser.add_argument("--librilight_ratio", type=float, default=1, help='the portion of librilight compared to gigaspeech, 1 means equal, 2 means librilight data is twice as much as gigaspeech')
50
+ parser.add_argument("--plus_librilight_root", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the root folder to librilight. Note that will need to merge the vocab.txt based on gigaspeech's, in order to be able to load a pretrained model")
51
+ parser.add_argument("--plus_librilight_phn_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the phoneme folder name of librilight")
52
+ parser.add_argument("--plus_librilight_encodec_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the encodec folder name of librilight")
53
+ parser.add_argument("--plus_librilight_manifest_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the manifest folder name of librilight")
54
+ parser.add_argument("--skip_us", type=int, default=0, help="skip the giga utterances that contains 'j uː ɛ s' because of the tokenization issue")
55
+ parser.add_argument("--pseudo_epoch_size", type=int, default=37901, help="only use for Eden scheduler. 37901 is the epoch size in the default optim setting, this is probably too big")
56
+ parser.add_argument("--switch_order", type=int, default=0, help="this is only for hificodec, where we switch the order of the 2nd and 3rd codebooks")
57
+ parser.add_argument("--phn_folder_name", type=str, default="phoneme", help="for libritts I also have arpa phns, in which case should be phonemes_arpa")
58
+ parser.add_argument("--encodec_folder_name", type=str, default="mimi_8cb", help="folder where encodec codes are stored")
59
+ parser.add_argument("--manifest_name", type=str, default="manifest_final", help="if using hificodec, it should be hificodec_menifest, if using encodec, it is the default")
60
+ parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0")
61
+ parser.add_argument("--max_num_tokens", type=int, default=18750, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu")
62
+ parser.add_argument("--val_max_num_tokens", type=int, default=6000, help="FOR validation, this basically is for music-gen because of high mem consumption. max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu")
63
+ parser.add_argument("--num_buckets", type=int, default=10)
64
+ parser.add_argument("--dynamic_batching", type=int, default=1)
65
+ parser.add_argument("--audio_max_length", type=float, default=120, help="in seconds, crop the audio if its length is longer than this")
66
+ parser.add_argument("--audio_min_length", type=float, default=2, help="in seconds, drop the audio if its length is shorter than this")
67
+ parser.add_argument("--text_max_length", type=int, default=1000, help='if too long, we crop')
68
+ parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop")
69
+ parser.add_argument("--encodec_sr", type=float, default=50, help="for 24kHz mimi model, it produces 12.5 codes for 1 sec of audio")
70
+ parser.add_argument('--mask_len_min', type=int, default=20, help='Minimum mask length')
71
+ parser.add_argument('--mask_len_max', type=int, default=400, help='Maximum mask length')
72
+ parser.add_argument('--extra_mask_len_min', type=int, default=2, help='Minimum extra mask length')
73
+ parser.add_argument('--extra_mask_len_max', type=int, default=20, help='Maximum extra mask length')
74
+ parser.add_argument('--final_audio_token_len', type=int, default=772, help="this is only for stage 1 training, since we add eog, start_of_continue, and a random amount of extra mask, --audio_max_length won't be the final max length, the self.args.final_audio_token_len = self.args.audio_max_length*self.args.encodec_sr+self.args.extra_mask_len_max+2 ")
75
+
76
+ # model
77
+ parser.add_argument("--ttsonly", default=0, type=int, help="if 1, only train tts model, no CM3")
78
+ parser.add_argument("--load_existing_text_embedding", type=int, default=0, help="if 1, when load model and the text vocab doesn't match, will load the existing weights while the new weights will be initialized randomly")
79
+ parser.add_argument("--fly", type=int, default=0, help="if 1, encode chunked audio on the fly")
80
+ parser.add_argument("--encodec_ckpt", type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
81
+ parser.add_argument("--downsample_rate", type=int, default=320, help="the downsample rate for the encodec model, 16000/320 = 50Hz")
82
+ parser.add_argument("--segtts_mask", type=int, default=0, help="if 1, use segtts_mask model, where we have a prefix and segment utterance into two and shifted separately for modeling, and use make use of mask:0, by insert two mask:0 in the middle of the two segments")
83
+ parser.add_argument("--segtts", type=int, default=0, help="if 1, use segtts model, where we have a prefix and segment utterance into two and shifted separately for modeling")
84
+ parser.add_argument("--edge", type=int, default=0, help="if 1, use edge prediction for the first codebook")
85
+ parser.add_argument("--duration_loss_weight", type=float, default=1.0, help="weight on the duration loss")
86
+ parser.add_argument("--drop_long", type=int, default=1, help="if this is true, will drop examples whose encodec sequence or phone sequence is too long, rather than cropping as we did before, to avoid hallucination")
87
+ parser.add_argument("--eos", type=int, default=2051, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4")
88
+ parser.add_argument("--reduced_eog", type=int, default=1, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts")
89
+
90
+ parser.add_argument("--valle_orig", type=int, default=0, help="the original valle model, trained for TTS")
91
+ parser.add_argument("--valle_max_prompt_len", type=float, default=6, help='in sec.')
92
+ # randomly choose a portion as tts examples during training
93
+ parser.add_argument("--tts_portion", type=float, default=0, help="randomly choose a portion of the training examples as tts examples, where no mask and rearrangement is used")
94
+
95
+ # put special tokens first to handle different vocab_size
96
+ parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens")
97
+ parser.add_argument("--n_special", type=int, default=4, help="empty, eog, pad, eos")
98
+
99
+ # weight codebook differently
100
+ parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']")
101
+
102
+ # args for MusicGen
103
+ parser.add_argument("--mask_span_weight", default=1.0, type=float, help="the weight on the tokens in masked span")
104
+ parser.add_argument("--unmask_span_weight", default=1.0, type=float, help="the weight on unmasked span")
105
+ parser.add_argument("--start_end_weight", default=None, type=str, help="weight the start x tokens and end x tokens differently, e.g. (10,2.0), means x == 10, weight==2.0")
106
+ # for now not consider the two weights above, only consider eog_weight, which is defined below somewhere, as the above two are not super principled
107
+
108
+ parser.add_argument("--musicgen", type=int, default=0, help="whether or not use this model, will also have an impact on the output shape of the dataset")
109
+ parser.add_argument("--enc_dec", default=0, type=int, help="use enc-dec architecture, text is from the enc, only for musicgen")
110
+ parser.add_argument("--dec", default=0, type=int, help="use dec only architecture, text is from the enc, only for musicgen. Exclusive with --enc_dec")
111
+ parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook")
112
+ # args for the optimizer and scheduler from Feiteng
113
+ # original setup for the 3 params are 5000 4 and 1000
114
+ # but that's because set_epoch is run on num_gradient_accumulation_step*step (with 4 being the accumulation step)
115
+ # so I scaled down them a little bit
116
+ # will try scaling them back if this doesn't work
117
+ parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler")
118
+ parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer')
119
+ parser.add_argument("--reduce_lr_start_epoch", type=int, default=4)
120
+ parser.add_argument("--clipping_update_period", type=int, default=600)
121
+
122
+
123
+ # below are args for valle
124
+ # below are args for valle
125
+ parser.add_argument("--valle", type=int, default=0, help="if 1, use valle model (cm3)")
126
+ parser.add_argument("--decoder_dim", type=int, default=1024)
127
+ parser.add_argument("--norm_first", action="store_true", default=True)
128
+ parser.add_argument("--add_prenet", action="store_true", default=False)
129
+ parser.add_argument("--prefix_mode", type=int, default=5, help="this is for NAR, we only do 5, which is CM3")
130
+ parser.add_argument("--share_embedding", action="store_true", default=False)
131
+ parser.add_argument("--nar_scale_factor", type=float, default=1.0)
132
+ parser.add_argument("--prepend_bos", action="store_true", default=False)
133
+ parser.add_argument("--sync_nar", type=int, default=0, help="whether to choose the same NAR model to run for training_stage==2 across different process (this is only for DDP)")
134
+ # above are args for valle
135
+ # above are args for valle
136
+
137
+
138
+
139
+ # add parallel_pattern
140
+ parser.add_argument("--parallel_pattern", type=int, default=0, help="if 1, use parallel pattern, we also use LFSC codec")
141
+ parser.add_argument("--full_prediction", type=int, default=0, help='this is for ve1, if 1, use full autoregressive mask, and calculate loss over all tokens, except for mask_tokens')
142
+ parser.add_argument("--multicm3", type=int, default=0, help='cm3 model but allows multiple mask spans')
143
+ parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask a utterance for more than this portion")
144
+ parser.add_argument("--max_n_spans", type=int, default=8, help='maximal number of spans, only used when using multicm3, this is used to decide the number of mask_embedding, and the max clamp value if using a Poisson distribution; if using a uniform distribution to sample the number of spans, it will be uniform(1,max_n_spans)')
145
+ parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether to shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has its benefit, as it makes sure that mask:0 always appears first")
146
+ parser.add_argument("--mask_sample_dist", type=str, default="uniform", help="uniform or poissonx, e.g. poisson1, meaning the parameter lambda is 1, it will most likely sample 1 masks")
147
+ parser.add_argument("--min_gap", type=int, default=10, help="after starts are sampled, delete the later one if it is closer to the former start than min_gap")
148
+
149
+ parser.add_argument('--cm3', type=int, default=0, help="use cm3 style for ve1, the input from the dataloader is going to be just raw data, all masking and rearrangement will happen within the model")
150
+
151
+ parser.add_argument('--sep_special_token', type=int, default=0, help="remove text/audio pad token, set audio_mask_token and start of continue to be separately learned embeddings. Therefore, for ve1 self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size + 2, for ve7, self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size")
152
+ parser.add_argument('--one_causal', type=int, default=0, help="whether model VE_one generation as autoregressive gen or non-autoregressive gen")
153
+ parser.add_argument('--n_codebooks', type=int, default=8)
154
+ parser.add_argument('--weight_sharing', type=int, default=0, help="sharing weights between VE_seven predict layer and embedding layer")
155
+ parser.add_argument('--text_vocab_size', type=int, default=86, help='Size of text vocabulary')
156
+ parser.add_argument('--text_pad_token', type=int, default=86, help='padding of the text tokens, not attended')
157
+ # parser.add_argument('--audio_vocab_size', type=int, default=1024, help='Size of audio vocabulary')
158
+ parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary, can be specified as '[128,512,1024,2048]'")
159
+ parser.add_argument('--audio_mask_token', type=int, default=1024, help='Audio mask token, this is the extra mask used in the masked region for AR; for NAR, the entire masked region will be filled with it')
160
+ parser.add_argument('--bog', type=int, default=1025, help='Begin of generation token')
161
+ parser.add_argument('--eog', type=int, default=2049, help='End of generation token')
162
+ parser.add_argument('--start_of_continue', type=int, default=1027, help='this token follows the masked region and precedes the first unmasked token, to indicate that the gt tokens start')
163
+ parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended')
164
+ parser.add_argument('--d_model', type=int, default=1024, help='Model dimension')
165
+ parser.add_argument('--audio_embedding_dim', type=int, default=128, help='dimension of the continuous encodec embedding (before being quantized)')
166
+ parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding')
167
+ parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding')
168
+ parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding')
169
+ parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding')
170
+ parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer')
171
+ parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads')
172
+ parser.add_argument('--num_encoder_layers', type=int, default=12, help='Number of encoder layers')
173
+ parser.add_argument('--num_decoder_layers', type=int, default=12, help='Number of decoder layers')
174
+ parser.add_argument('--eog_weight', type=float, default=1.0, help='Weight for End of generation token')
175
+ parser.add_argument('--stage_one_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage one. On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks')
176
+ parser.add_argument('--stage_two_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage two. On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks')
177
+ parser.add_argument('--stage_two_load_ve_one_embedding', type=str, default=None, help='Path to load VoiceEditor_one audio embedding for stage two')
178
+ parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume')
179
+ parser.add_argument('--load_model_from_ve1', type=str, default=None, help='Path to load ve1 model weights from, this will be effective last, designed for loading the encoder weights of the VE7 from a pretrained VE1')
180
+
181
+
182
+ ## below are args for the new long model
183
+ parser.add_argument("--target_time_stretch_prob", type=float, default=0, help="the probability of time stretching the target audio")
184
+ parser.add_argument("--target_time_stretch_bound", type=float, default=0.1, help="the bound of the time stretching target audio, e.g. 0.1 means the audio will be stretched by 0.9 to 1.1")
185
+ parser.add_argument("--time_stretch_prob", type=float, default=0, help="the probability of time stretching the audio")
186
+ parser.add_argument("--time_stretch_bound", type=float, default=0.3, help="the bound of the time stretching, e.g. 0.3 means the audio will be stretched by 0.7 to 1.3")
187
+ parser.add_argument("--no_loss_on_prefix", type=int, default=0, help="if 1, will not calculate loss on the prefix acoustic tokens")
188
+ parser.add_argument("--x_sep_token", type=int, default=None, help="if not None, will use this token in between prompt text and target generation text")
189
+ parser.add_argument("--y_sep_token", type=int, default=None, help="if not None, will use this token in between prompt codec tokens and target codec tokens")
190
+ parser.add_argument("--neighbor_prompt_prob", type=float, default=0, help="the probability of using the prompt from the neighbor")
191
+ parser.add_argument("--neighbor_folder_name", type=str, default='neighbors',help="folder where the neighbors of the current audio files are stored, each row contains three tab separated entries: neighbor_fn, neighbor_temporal_distance, neighbor_duration")
192
+ parser.add_argument("--alignment_folder_name", type=str, default='alignment', help="folder where the forced alignment of the current audio files are stored, in csv format, each row contains five comma separated entries: begin, end, label, type, speaker, the first row is header")
193
+ parser.add_argument("--ipa_alignment_folder_name", type=str, default='ipa_alignment', help="folder where the forced alignment of the current audio files are stored, in txt format, each row contains three tab separated entries: begin, end, ipa phn sequence, generated using data/ll60k_preprocessing/step7_ipa_alignment.py")
194
+ parser.add_argument("--max_prompt_len", type=float, default=30, help="in sec., maximal prompt length selected from some neighboring file")
195
+ parser.add_argument("--min_prompt_len", type=float, default=0.5, help="in sec., minimal prompt length selected from some neighboring file")
196
+ parser.add_argument("--neighbor_selection_method", type=str, default="maxdist_60", help="maxdist_60 means uniformly select a neighbor that's within 60 sec of the current audio file")
197
+ parser.add_argument("--num_trial", type=int, default=5, help="number of tries to select a neighbor")
198
+ parser.add_argument("--prompt_start_from_begining_prob", type=float, default=0.5, help="the probability of starting the prompt from the beginning of the neighbor")
199
+ parser.add_argument("--min_alignment_len", type=int, default=5, help="in number of words")
200
+ parser.add_argument("--audio_folder_name", type=str, default='audio', help="folder where the audio files are stored")
201
+
202
+ # rope parameters
203
+ parser.add_argument("--decoder_regular_rope", type=int, default=0, help="if 1, will use regular rope for the decoder (note that we always use regular rope for encoder). ")
204
+ parser.add_argument("--progress_no_multiple", type=int, default=0, help="if 1, will not multiply the percentage progress by the length of the key, see apply_rotary_pos_emb in models/modules/activation.py, this applies to both rope and sinusoidal positional encoding. Note that progress scale is still applied, i.e. when we only apply progress scale but do not multiply, the scaling factor is constant for every sample, rather than sample dependent")
205
+ parser.add_argument("--add_eos_to_text", type=int, default=0, help="if not 0, use this number as eos and add to the end of text token, usually use the second to last token in the vocab size")
206
+ parser.add_argument("--add_bos_to_text", type=int, default=0, help="if not 0, use this number as bos and add to the beginning of text token, usually use the third to last token in the vocab size")
207
+ parser.add_argument("--use_sinusoidal", type=int, default=0, help="if 1, will use sinusoidal positional encoding, otherwise use rope. BUT if rope_base is None, will use sinusoidal")
208
+ parser.add_argument("--sinusoidal_base", type=int, default=1e4, help="the base of the exponential function, default is 1e4")
209
+ parser.add_argument("--use_sinusoidal_progress", type=int, default=0, help="if 1, will use sinusoidal positional encoding for progress, otherwise use rope")
210
+ parser.add_argument("--rope_base", type=int, default=None, help="the base of the exponential function, default is 1e4, if None, will not use rope")
211
+ parser.add_argument("--multiple_key_length", type=int, default=0, help="if 1, during progress calculation, will multiply the percentage progress by the length of the key, otherwise multiply by the length of the query. see models/rope_playground.ipynb")
212
+ parser.add_argument("--progress_scale", type=float, default=1.0, help="scale the progress, the smaller the value, the bigger the diagonal in attention score, see models/rope_playground.ipynb")
213
+
214
+ # attention alignment loss
215
+ parser.add_argument("--attention_alignment_loss", type=float, default=0.0, help="the weight on the attention alignment loss, if 0, will not calculate the loss")
216
+ parser.add_argument("--alignment_loss_layer", type=str, default="['0-1', '2', '3']", help='the layers to calculate the alignment loss, e.g. ["0-1", "2", "3"]')
217
+ parser.add_argument("--alignment_loss_head", type=str, default="['0-1', '2', '3']", help='the attention heads to calculate the alignment loss, e.g. ["0-1", "2", "3"]')
218
+ parser.add_argument("--alignment_blank_logit", type=float, default=-1.0, help="the logit for the blank token added to the attention weights")
219
+
220
+ # inference parameters
221
+ parser.add_argument("--metrics", type=str, default="['spk_sim','wer','mcd','pitch','energy','pesq','utmos']")
222
+ parser.add_argument("--res_jsonl_root", type=str, default="/home/pyp/BoostedVoiceEditor/res")
223
+ parser.add_argument("--res_name", type=str, default="2jan25.jsonl")
224
+ parser.add_argument("--inference_seed", type=int, default=1)
225
+ parser.add_argument("--codec_audio_sr", type=int, default=16000)
226
+ parser.add_argument("--codec_sr", type=float, default=50)
227
+ parser.add_argument("--top_k", type=int, default=0)
228
+ parser.add_argument("--top_p", type=float, default=0.9)
229
+ parser.add_argument("--temperature", type=float, default=1)
230
+ parser.add_argument("--silence_tokens", type=list, default=[])
231
+ parser.add_argument("--kvcache", type=int, default=0)
232
+ parser.add_argument("--stop_repetition", type=int, default=3)
233
+ parser.add_argument("--sample_batch_size", type=int, default=1)
234
+ parser.add_argument("--inference_manifest_fns", type=str, default="['/home/pyp/BoostedVoiceEditor/manifests/debug.jsonl']")
235
+ parser.add_argument("--use_gt_duration", type=int, default=1)
236
+ parser.add_argument("--save_root", type=str, default="/data/scratch/pyp/exp_pyp/BoostedVoiceEditor/gens")
237
+ parser.add_argument("--encodec_signature", type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
238
+ parser.add_argument("--extra_cutoff", type=float, default=5, help="in rare cases where the model doesn't follow specified target duration (only happened in extrapolation cases), we will terminate generation once the extra duration exceeds this value")
239
+ parser.add_argument("--duration_margin", type=float, default=0.04, help="used along with extra_cutoff, when extra_cutoff is used (i.e. model doesn't follow specified target_duration), we terminate the generation and cut the results to target_duration + duration_margin")
240
+ # add repeat_prompt and asr_model_name
241
+ parser.add_argument("--repeat_prompt", type=int_or_str, default=0, help="if 1, will repeat the prompt for each segment")
242
+ parser.add_argument("--asr_model_name", type=str, default="w2v2", help="the name of the asr model, if not None, will use the asr model to generate the prompt")
243
+
244
+ # depth transformer parameters
245
+ parser.add_argument("--depth_dec_num_layers", type=int, default=0)
246
+ parser.add_argument("--depth_dec_d_model", type=int, default=768)
247
+ parser.add_argument("--depth_dec_nhead", type=int, default=12)
248
+ parser.add_argument("--moshi_depth", type=int, default=0, help="if 1, will use the same parameterization as moshi, i.e. temporal trm output will get added to every transformed token embedding")
249
+
250
+ parser.add_argument("--validation_sample_cap", type=int, default=None, help="cap the validation data to this number")
251
+ parser.add_argument("--no_libri_in_training", type=int, default=None, help="if 1, will not use librilight in training, only use in validation")
252
+ parser.add_argument("--uniform_weight_start_step", type=int, default=1e50, help="set all codebook weight to be uniform starting from this step")
253
+
254
+ return parser
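config.py only builds and returns the parser; the training/inference entry points (e.g. main.py, which is not shown in this truncated view) are responsible for calling it. A minimal consumption sketch follows; every flag value is a placeholder, and note that several arguments (audio_vocab_size, codebook_weight, inference_manifest_fns, ...) are strings that encode Python lists and are turned into objects with eval elsewhere in the codebase.

from config import MyParser

if __name__ == "__main__":
    parser = MyParser()
    # Placeholder command line for illustration only; real runs pass many more flags.
    args = parser.parse_args([
        "--dataset", "librilight",
        "--dataset_dir", "/path/to/librilight/preprocessed",      # placeholder path
        "--exp_dir", "/path/to/experiments",                      # placeholder path
        "--inference_manifest_fns", "['/path/to/debug.jsonl']",   # list encoded as a string
    ])
    manifest_fns = eval(args.inference_manifest_fns)  # -> ['/path/to/debug.jsonl']
    print(args.dataset, args.encodec_sr, args.n_codebooks, manifest_fns)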
copy_codebase.py ADDED
@@ -0,0 +1,56 @@
1
+
2
+ import os
3
+ import shutil
4
+ import fnmatch
5
+
6
+ def parse_gitignore(gitignore_path):
7
+ """Parse a .gitignore file and return a list of patterns."""
8
+ patterns = []
9
+ with open(gitignore_path, "r") as f:
10
+ for line in f:
11
+ # Ignore comments and blank lines
12
+ line = line.strip()
13
+ if not line or line.startswith("#"):
14
+ continue
15
+ # Handle wildcards and directory separators
16
+ patterns.append(line)
17
+ return patterns
18
+
19
+ def file_matches_patterns(file_path, patterns):
20
+ """Check if a file matches any of the patterns in .gitignore."""
21
+ for pattern in patterns:
22
+ if fnmatch.fnmatch(file_path, pattern):
23
+ return True
24
+ return False
25
+
26
+ def copy_codebase(src, dst, max_size_mb=5, gitignore_path=None):
27
+ """ Copy files from src to dst, skipping files larger than max_size_mb and matching .gitignore patterns. """
28
+ if gitignore_path and os.path.exists(gitignore_path):
29
+ patterns = parse_gitignore(gitignore_path)
30
+ else:
31
+ patterns = []
32
+ print("patterns to ignore: ", patterns)
33
+ os.makedirs(dst, exist_ok=True)
34
+ for root, dirs, files in os.walk(src):
35
+ for file in files:
36
+ file_path = os.path.join(root, file)
37
+ relative_path = os.path.relpath(file_path, src)
38
+ dst_path = os.path.join(dst, relative_path)
39
+ # ignore .git because of permission issues
40
+ if "/.git/" in file_path:
41
+ continue
42
+
43
+ # Check .gitignore patterns
44
+ if file_matches_patterns(file_path, patterns):
45
+ # print(f"Skipping {file_path} because it matches a pattern in .gitignore")
46
+ continue
47
+
48
+ # Check file size
49
+ if os.path.getsize(file_path) > max_size_mb * 1024 * 1024:
50
+ print(f"Skipping {file_path} because it's larger than {max_size_mb}MB")
51
+ continue
52
+
53
+
54
+ # Make sure the destination directory exists
55
+ os.makedirs(os.path.dirname(dst_path), exist_ok=True)
56
+ shutil.copy(file_path, dst_path)
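copy_codebase.py is a small utility for snapshotting the source tree (e.g. into an experiment directory) while skipping large files and anything matched by .gitignore; the matching uses fnmatch on full paths, which is simpler than full .gitignore semantics (no negation patterns or directory-only rules). A hypothetical invocation, with placeholder paths:

from copy_codebase import copy_codebase

copy_codebase(
    src=".",                               # project root to snapshot
    dst="/tmp/voicestar_code_snapshot",    # placeholder destination
    max_size_mb=5,                         # skip anything larger than 5 MB
    gitignore_path=".gitignore",           # reuse the repo's ignore patterns if present
)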
data/__init__.py ADDED
File without changes
data/combined_dataset.py ADDED
@@ -0,0 +1,466 @@
1
+ import os
2
+ import ffmpeg
3
+ import torch
4
+ import random
5
+ import copy
6
+ import logging
7
+ import torch.distributed as dist
8
+ import shutil
9
+ import csv
10
+ import torchaudio
11
+ import glob
12
+ import numpy as np
13
+ from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer
14
+ def find_files(root_dir, endswith=".wav"):
15
+ files = []
16
+ # os.walk generates the file names in a directory tree
17
+ for dirpath, dirnames, filenames in os.walk(root_dir):
18
+ for filename in filenames:
19
+ # os.path.splitext splits the file name into a base and extension
20
+ # base, ext = os.path.splitext(filename)
21
+ if filename.lower().endswith(endswith):
22
+ # os.path.join combines one or more path names into a single path
23
+ full_path = os.path.join(dirpath, filename)
24
+ files.append(full_path)
25
+ return files
26
+
27
+ class dataset(torch.utils.data.Dataset):
28
+ def __init__(self, args, split):
29
+ super().__init__()
30
+ self.args = args
31
+ self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0)
32
+ self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1)
33
+ self.split = split
34
+
35
+ assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}"
36
+
37
+ if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir:
38
+ self.dataset_dir = f"['{self.args.dataset_dir}']"
39
+ else:
40
+ self.dataset_dir = copy.deepcopy(self.args.dataset_dir)
41
+ self.dataset_dir = eval(self.dataset_dir)
42
+ data = []
43
+ if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name:
44
+ self.args.manifest_name = f"['{self.args.manifest_name}']"
45
+ else:
46
+ self.args.manifest_name = copy.deepcopy(self.args.manifest_name)
47
+ self.manifest_name = eval(self.args.manifest_name)
48
+ if len(self.manifest_name) != len(self.dataset_dir):
49
+ assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}"
50
+ self.manifest_name = self.manifest_name * len(self.dataset_dir)
51
+ for i_data, dataset_dir in enumerate(self.dataset_dir):
52
+ if getattr(self.args, "no_libri_in_training", None) != None and ("librilight" in dataset_dir) and self.split == "train":
53
+ if not dist.is_initialized() or dist.get_rank() == 0:
54
+ logging.info(f"skipping librilight in training split")
55
+ continue
56
+ n_datapoints = 0
57
+ manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt")
58
+ if not os.path.isfile(manifest_fn):
59
+ all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt"))
60
+ if len(all_manifest_fn) == 0:
61
+ logging.info(f"no manifest file found for {split} split in {dataset_dir}")
62
+ continue
63
+ if self.args.debug:
64
+ logging.info(f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}")
65
+ all_manifest_fn = all_manifest_fn[:1]
66
+ else:
67
+ if dist.is_initialized() and dist.get_rank() == 0:
68
+ logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}")
69
+ for cur_manifest_fn in all_manifest_fn:
70
+ with open(cur_manifest_fn, "r") as rf:
71
+ tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] # i_data is the index of the dataset
72
+ n_datapoints += len(tmp)
73
+ data += tmp
74
+ else:
75
+ with open(manifest_fn, "r") as rf:
76
+ tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()]
77
+ data += tmp
78
+ n_datapoints += len(tmp)
79
+ if dist.is_initialized() and dist.get_rank() == 0:
80
+ logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}")
81
+ assert len(data) > 0, f"no data found for {split} split"
82
+ lengths_list = [int(item[1]) for item in data] # use index 1 because there might be more than two columns (for gigaspeech we have 3 columns: path, duration, selfsim)
83
+ self.data = []
84
+ self.lengths_list = []
85
+ total_duration = 0
86
+ for d, l in zip(data, lengths_list):
87
+ if l >= self.args.encodec_sr*self.args.audio_min_length:
88
+ if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
89
+ continue
90
+ self.data.append(d)
91
+ self.lengths_list.append(l)
92
+ total_duration += l / self.args.encodec_sr / 3600
93
+ # logging.info(f"for now cut the dataset to only have 500 examples for debugging")
94
+ # self.data = self.data[:1000]
95
+ # self.lengths_list = self.lengths_list[:1000]
96
+ if dist.is_initialized() and dist.get_rank() == 0:
97
+ logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}")
98
+ logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours")
99
+ # phoneme vocabulary
100
+ phn_set = set()
101
+ for dataset_dir in self.dataset_dir:
102
+ vocab_fn = os.path.join(dataset_dir, "vocab.txt")
103
+ with open(vocab_fn, "r") as f:
104
+ temp = [l.strip().split("\t") for l in f.readlines() if len(l) != 0]
105
+ phn_set.update([item[-1] for item in temp])
106
+ self.phn2num = {item:i for i, item in enumerate(phn_set)}
107
+ assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}"
108
+
109
+ if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0:
110
+ userdir = os.path.expanduser("~")
111
+ encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"))
112
+ self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True)
113
+ assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.codec_audio_sr: {self.args.codec_audio_sr}"
114
+ if dist.is_initialized() and dist.get_rank() == 0:
115
+ logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}")
116
+
117
+ def __len__(self):
118
+ return len(self.lengths_list)
119
+
120
+ def _load_phn_enc(self, index):
121
+ item = self.data[index]
122
+ dataset_dir = self.dataset_dir[item[-1]]
123
+ pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt")
124
+ ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt")
125
+ # with a certain probability, we load the audio and time-stretch it; note that the stretched result must not exceed self.args.audio_max_length
126
+ if "/librilight" in dataset_dir:
127
+ audio_ext = ".flac"
128
+ elif "/emilia" in dataset_dir:
129
+ audio_ext = ".mp3"
130
+ else:
131
+ raise NotImplementedError(f"dataset_dir: {dataset_dir}")
132
+
133
+ audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext)
134
+ speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1
135
+ length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length # NOTE: the maximal duration after time stretching is orig/(1-bound), not orig*(1+bound)
136
+ if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok:
137
+ try:
138
+ with open(pf, "r") as p:
139
+ phns = [l.strip() for l in p.readlines()]
140
+ assert len(phns) == 1, phns
141
+ all_phns = phns[0].split(" ")
142
+ x = [self.phn2num[item] for item in all_phns if item in self.phn2num]
143
+ except:
144
+ logging.info(f"loading failed for {pf}, maybe files don't exist or are corrupted")
145
+ return [], [[]], dataset_dir, audio_ext
146
+ # time stretch
147
+ try:
148
+ process = (
149
+ ffmpeg.input(audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr)
150
+ .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
151
+ .run_async(pipe_stdout=True, pipe_stderr=True)
152
+ )
153
+ # Read the processed audio from ffmpeg stdout
154
+ output, _ = process.communicate()
155
+
156
+ # Convert the output to a numpy array
157
+ output_np = np.frombuffer(output, dtype=np.float32).copy()
158
+
159
+ # Reshape the numpy array back to the expected shape (1, samples for mono)
160
+ waveform = torch.from_numpy(output_np)
161
+ waveform = waveform.unsqueeze(0).unsqueeze(0)
162
+ assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
163
+ with torch.no_grad():
164
+ encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
165
+ assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
166
+ encos = encos.cpu().squeeze(0).numpy().tolist() # [K, T]
167
+ if self.args.special_first:
168
+ raise NotImplementedError
169
+ # y = [[int(n)+self.args.n_special for n in l] for l in encos]
170
+ else:
171
+ y = [[int(n) for n in l] for l in encos]
172
+ return x, y, dataset_dir, audio_ext
173
+ except Exception as e:
174
+ logging.info(f"failed with time stretch and codec encode for {audio_fn}")
175
+ logging.info(f"error: {e}")
176
+ pass
177
+
178
+ try:
179
+ with open(pf, "r") as p, open(ef, "r") as e:
180
+ phns = [l.strip() for l in p.readlines()]
181
+ assert len(phns) == 1, phns
182
+ all_phns = phns[0].split(" ")
183
+ x = [self.phn2num[item] for item in all_phns if item in self.phn2num] # we assume that OOD will not happen, because phn vocab is small
184
+ encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
185
+
186
+ assert len(encos) == self.args.n_codebooks, ef
187
+
188
+ if self.args.special_first:
189
+ raise NotImplementedError
190
+ # y = [[int(n)+self.args.n_special for n in l] for l in encos]
191
+ else:
192
+ y = [[int(n) for n in l] for l in encos]
193
+ except:
194
+ logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
195
+ return [], [[]], dataset_dir, audio_ext
196
+
197
+ return x, y, dataset_dir, audio_ext
198
+
199
+ # this uses the output of step7_ipa_alignment.py
200
+ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext):
201
+ neighbor = random.choice(neighbors)
202
+ neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0])
203
+ if not os.path.isfile(neighbor_enc_fn):
204
+ return None, None
205
+ neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext))
206
+ if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path):
207
+ logging.info(f"audio file not found: {neighbor_audio_path}")
208
+ return None, None
209
+ if random.random() < getattr(self.args, "time_stretch_prob", 0):
210
+ time_stretch_flag = True
211
+ speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1
212
+ duration_factor = 1 / speed_factor
213
+ else:
214
+ time_stretch_flag = False
215
+ duration_factor = 1
216
+
217
+ ####################### TODO for now always use the entire neighbor for emilia
218
+ ####################### TODO for now always use the entire neighbor for emilia
219
+ # for gigaspeech and emilia we did not run MFA forced alignment (and therefore have no IPA alignment), so we just use the entire neighbor as the prompt
220
+
221
+ if "/emilia" in dataset_dir:
222
+ # get neighbor duration
223
+ neighbor_dur = float(neighbor[2])
224
+ if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len:
225
+ return None, None
226
+ try:
227
+ neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0])
228
+ with open(neighbor_pf, "r") as p:
229
+ phns = [l.strip() for l in p.readlines()]
230
+ assert len(phns) == 1, phns
231
+ all_phns = phns[0].split(" ")
232
+ phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num]
233
+ except:
234
+ logging.info(f"loading failed for {neighbor_pf}, maybe files don't exist")
235
+ return None, None
236
+ # if do not stretch the audio
237
+ if not time_stretch_flag:
238
+ with open(neighbor_enc_fn, "r") as f:
239
+ neighbor_enc = [l.strip().split() for l in f.readlines()]
240
+ if len(neighbor_enc) != self.args.n_codebooks:
241
+ return None, None
242
+ # if too long
243
+ else:
244
+ if self.args.special_first:
245
+ raise NotImplementedError
246
+ # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc]
247
+ else:
248
+ neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
249
+
250
+ return phn_token, neighbor_enc
251
+ else: # stretch the audio with ffmpeg-python
252
+ process = (
253
+ ffmpeg.input(neighbor_audio_path, ss=0, t=neighbor_dur)
254
+ .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
255
+ .run_async(pipe_stdout=True, pipe_stderr=True)
256
+ )
257
+ # Read the processed audio from ffmpeg stdout
258
+ output, _ = process.communicate()
259
+
260
+ # Convert the output to a numpy array
261
+ output_np = np.frombuffer(output, dtype=np.float32).copy()
262
+
263
+ # Reshape the numpy array back to the expected shape (1, samples for mono)
264
+ waveform = torch.from_numpy(output_np)
265
+ waveform = waveform.unsqueeze(0).unsqueeze(0)
266
+ assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
267
+ with torch.no_grad():
268
+ encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
269
+ assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
270
+ neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T]
271
+ return phn_token, neighbor_enc
272
+ ####################### TODO for now always use the entire neighbor for emilia
273
+ ####################### TODO for now always use the entire neighbor for emilia
274
+ ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0])
275
+ if not os.path.isfile(ipa_alignment_fn):
276
+ # print(f"file not found: {ipa_alignment_fn}", flush=True)
277
+ return None, None
278
+ with open(ipa_alignment_fn, "r") as f:
279
+ alignments = [l.strip().split("\t") for l in f.readlines()]
280
+ alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3]
281
+ alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len]
282
+ if len(alignments) == 0:
283
+ # print(f"no valid alignment found for {ipa_alignment_fn}")
284
+ return None, None
285
+ idx = random.choice(range(len(alignments)))
286
+ while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length:
287
+ idx -= 1
288
+ if idx < 0:
289
+ # print(f"too long combined with y_len {ipa_alignment_fn=}, and {y_len=}")
290
+ return None, None
291
+ if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len:
292
+ return None, None
293
+
294
+ start_time, end_time = alignments[idx][:2]
295
+ phn = alignments[idx][2].split(" ")
296
+ phn_token = [self.phn2num[item] for item in phn if item in self.phn2num]
297
+ if len(phn_token) == 0:
298
+ return None, None
299
+
300
+ if time_stretch_flag:
301
+ duration = end_time - start_time
302
+ process = (
303
+ ffmpeg.input(neighbor_audio_path, ss=start_time, t=duration)
304
+ .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
305
+ .run_async(pipe_stdout=True, pipe_stderr=True)
306
+ )
307
+ # Read the processed audio from ffmpeg stdout
308
+ output, _ = process.communicate()
309
+
310
+ # Convert the output to a numpy array
311
+ output_np = np.frombuffer(output, dtype=np.float32).copy()
312
+
313
+ # Reshape the numpy array back to the expected shape (1, samples for mono)
314
+ waveform = torch.from_numpy(output_np)
315
+ waveform = waveform.unsqueeze(0).unsqueeze(0)
316
+ assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
317
+ try:
318
+ with torch.no_grad():
319
+ encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
320
+ except:
321
+ logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds")
322
+ return None, None
323
+ assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
324
+ neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T]
325
+ return phn_token, neighbor_enc
326
+ else:
327
+ # get encodec codes from storage
328
+ with open(neighbor_enc_fn, "r") as f:
329
+ neighbor_enc = [l.strip().split() for l in f.readlines()]
330
+ if len(neighbor_enc) != self.args.n_codebooks:
331
+ # print(f"wrong number of codebooks for {neighbor_enc_fn}")
332
+ return None, None
333
+ else:
334
+ # trim the encodec codes to the segment
335
+ start_enc_frame = int(start_time * self.args.encodec_sr)
336
+ end_enc_frame = int(end_time * self.args.encodec_sr)
337
+ neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc]
338
+ if len(neighbor_enc[0]) == 0:
339
+ # print(f"no valid encodec codes found for {neighbor_enc_fn}")
340
+ return None, None
341
+ if self.args.special_first:
342
+ raise NotImplementedError
343
+ else:
344
+ neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
345
+ return phn_token, neighbor_enc
346
+
347
+ def __getitem__(self, index):
348
+ x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
349
+ x_len, y_len = len(x), len(y[0])
350
+ extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0}
351
+ if x_len == 0 or y_len == 0:
352
+ ret = {
353
+ "x": None,
354
+ "x_len": None,
355
+ "y": None,
356
+ "y_len": None,
357
+ }
358
+ ret.update(extra_ret)
359
+ return ret
360
+ while y_len < self.args.encodec_sr*self.args.audio_min_length:
361
+ assert not self.args.dynamic_batching
362
+ index = random.choice(range(len(self))) # regenerate an index
363
+ x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
364
+ x_len, y_len = len(x), len(y[0])
365
+
366
+ # if use neighbor prompt
367
+ x_neighbor, y_neighbor = None, None
368
+ use_neighbor_prob = random.random()
369
+ neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt")
370
+ if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn): # it might not exist, just because we didn't find neighbor for this file (other than itself, which is common for emilia)
371
+ with open(neighbor_fn, "r") as f:
372
+ neighbors = [l.strip().split("\t") for l in f.readlines()]
373
+ # select neighbors
374
+ if "maxdist" in self.args.neighbor_selection_method:
375
+ maxdist = int(self.args.neighbor_selection_method.split("_")[-1])
376
+ # only keep neighbors with distance within maxdist
377
+ neighbors = [n for n in neighbors if float(n[1]) <= maxdist]
378
+ else:
379
+ raise NotImplementedError
380
+ x_neighbor, y_neighbor = None, None
381
+ if len(neighbors) > 0:
382
+ x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
383
+ i_trial = 0
384
+ while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors):
385
+ x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
386
+ i_trial += 1
387
+
388
+ if x_neighbor != None:
389
+ if self.args.x_sep_token != None:
390
+ x = x_neighbor + [self.args.x_sep_token] + x
391
+ else:
392
+ x = x_neighbor + x
393
+ if self.args.y_sep_token != None:
394
+ y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))]
395
+ else:
396
+ y = [y_neighbor[i] + y[i] for i in range(len(y))]
397
+ extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1 # if using y_sep_token, this is actually the position of the token right before the y_sep_token; since y_sep_token is ignored in loss computation, using that position is fine
398
+ extra_ret['x_sep_token_position'] = len(x_neighbor) + 1
399
+ x_len, y_len = len(x), len(y[0])
400
+
401
+
402
+ # consider adding eos to the end of the text
403
+ if self.args.add_eos_to_text != 0:
404
+ x.append(self.args.add_eos_to_text)
405
+ x_len += 1
406
+ if getattr(self.args, "add_bos_to_text", 0) != 0:
407
+ x = [self.args.add_bos_to_text] + x
408
+ x_len += 1
409
+ ### padding and cropping ###
410
+ ### padding and cropping ###
411
+ # adjust the length of encodec codes, pad to max_len or randomly crop
412
+ orig_y_len = copy.copy(y_len)
413
+ max_len = int(self.args.audio_max_length * self.args.encodec_sr)
414
+ if y_len > max_len + 10: # give it some margin for rounding error
415
+ raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}")
416
+ else:
417
+ audio_start = 0
418
+ if not self.args.dynamic_batching:
419
+ pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
420
+ for i in range(len(y)):
421
+ y[i] = y[i] + pad
422
+
423
+ if self.args.pad_x and x_len <= self.args.text_max_length:
424
+ pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
425
+ x = x + pad
426
+
427
+ ret = {
428
+ "x": torch.LongTensor(x),
429
+ "x_len": x_len,
430
+ "y": torch.LongTensor(y),
431
+ "y_len": y_len,
432
+ }
433
+ ret.update(extra_ret)
434
+
435
+ return ret
436
+
437
+
438
+ def collate(self, batch):
439
+ # make sure the keys in every batch item are the same
440
+ for batch1, batch2 in zip(batch[:-1], batch[1:]):
441
+ assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different"
442
+ out = {key:[] for key in batch[0]}
443
+ for item in batch:
444
+ if item['x'] == None: # deal with load failure
445
+ continue
446
+ for key, val in item.items():
447
+ out[key].append(val)
448
+ res = {}
449
+ if self.args.pad_x:
450
+ res["x"] = torch.stack(out["x"], dim=0)
451
+ else:
452
+ res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
453
+ res["x_lens"] = torch.LongTensor(out["x_len"])
454
+ if self.args.dynamic_batching:
455
+ res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
456
+ res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
457
+ else:
458
+ res['y'] = torch.stack(out['y'], dim=0)
459
+ res["y_lens"] = torch.LongTensor(out["y_len"])
460
+ res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
461
+ res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
462
+ if "y_sep_token_position" in out:
463
+ res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"])
464
+ if "x_sep_token_position" in out:
465
+ res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"])
466
+ return res
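To make the dynamic-batching branch of collate() above concrete, here is a small standalone sketch (dummy tensors and a hypothetical pad id, independent of the dataset class) showing how variable-length [K, T] codec sequences are padded and put back into [B, K, T] layout:

    import torch

    audio_pad_token = 2048   # hypothetical pad id
    K = 4                    # number of codebooks
    ys = [torch.randint(0, 1024, (K, 7)), torch.randint(0, 1024, (K, 5))]  # two items, like `y` above

    padded = torch.nn.utils.rnn.pad_sequence(
        [y.transpose(1, 0) for y in ys],   # each item becomes [T, K]
        padding_value=audio_pad_token,
    )                                      # -> [T_max, B, K]
    y_batch = padded.permute(1, 2, 0)      # -> [B, K, T_max], here torch.Size([2, 4, 7])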
data/emilia_preprocessing/delete_tar_files.sh ADDED
@@ -0,0 +1,42 @@
1
+ #!/bin/bash
2
+
3
+ # Define the root directory where the tar files are located
4
+ root=${root:-/data/scratch/pyp/datasets/emilia/downloads} # Example: /data/scratch/pyp/datasets/emilia/downloads
5
+ exist_log_file="file_log_debug.txt" # Log of files to delete
6
+ delete_log="deleted_files.log" # Log of successfully deleted files
7
+ error_log="delete_errors.log" # Log of errors (e.g., missing files)
8
+
9
+ # Clear previous logs
10
+ > "$delete_log"
11
+ > "$error_log"
12
+
13
+ echo "Starting deletion of tar files listed in $exist_log_file..."
14
+
15
+ # Loop through each line in exist_log_file
16
+ while IFS=',' read -r filename size local_sha256 original_filename url; do
17
+ # Trim leading/trailing whitespace
18
+ original_filename=$(echo "$original_filename" | xargs)
19
+
20
+ # Construct the full path to the tar file
21
+ tar_file="${root}/${original_filename}"
22
+
23
+ # Check if the tar file exists
24
+ if [ -f "$tar_file" ]; then
25
+ # Attempt to delete the file
26
+ if rm -f "$tar_file"; then
27
+ echo "✅ Deleted: $tar_file"
28
+ echo "$tar_file" >> "$delete_log"
29
+ else
30
+ echo "❌ Failed to delete: $tar_file"
31
+ echo "$tar_file" >> "$error_log"
32
+ fi
33
+ else
34
+ # Log missing files
35
+ echo "❌ File not found: $tar_file"
36
+ echo "$tar_file" >> "$error_log"
37
+ fi
38
+ done < "$exist_log_file"
39
+
40
+ echo "Deletion process completed."
41
+ echo "Deleted files are logged in $delete_log."
42
+ echo "Errors (if any) are logged in $error_log."
data/emilia_preprocessing/encodec.py ADDED
@@ -0,0 +1,1554 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from dataclasses import dataclass, field
12
+ import logging
13
+ import math
14
+ from pathlib import Path
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ from torch import einsum
21
+ import torch.nn.functional as F
22
+ from torch.nn.utils import spectral_norm, weight_norm
23
+
24
+ import logging
25
+ import warnings
26
+ from einops import rearrange, repeat
27
+ import omegaconf
28
+ # import flashy
29
+
30
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
31
+ 'time_group_norm'])
32
+
33
+ def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
34
+ """Convenience function to map an omegaconf configuration to a dictionary.
35
+
36
+ Args:
37
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
38
+ Returns:
39
+ dict: Config as dictionary object.
40
+ """
41
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
42
+ assert isinstance(dct, dict)
43
+ return dct
44
+
45
+ @dataclass
46
+ class QuantizedResult:
47
+ x: torch.Tensor
48
+ codes: torch.Tensor
49
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
50
+ penalty: tp.Optional[torch.Tensor] = None
51
+ metrics: dict = field(default_factory=dict)
52
+
53
+ class BaseQuantizer(nn.Module):
54
+ """Base class for quantizers.
55
+ """
56
+
57
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
58
+ """
59
+ Given input tensor x, returns first the quantized (or approximately quantized)
60
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
61
+ Finally, this returns a dict of metrics to update logging etc.
62
+ Frame rate must be passed so that the bandwidth is properly computed.
63
+ """
64
+ raise NotImplementedError()
65
+
66
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
67
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
68
+ raise NotImplementedError()
69
+
70
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
71
+ """Decode the given codes to the quantized representation."""
72
+ raise NotImplementedError()
73
+
74
+ @property
75
+ def total_codebooks(self):
76
+ """Total number of codebooks."""
77
+ raise NotImplementedError()
78
+
79
+ @property
80
+ def num_codebooks(self):
81
+ """Number of active codebooks."""
82
+ raise NotImplementedError()
83
+
84
+ def set_num_codebooks(self, n: int):
85
+ """Set the number of active codebooks."""
86
+ raise NotImplementedError()
87
+
88
+ class CompressionModel(ABC, nn.Module):
89
+ """Base API for all compression model that aim at being used as audio tokenizers
90
+ with a language model.
91
+ """
92
+
93
+ @abstractmethod
94
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
95
+ ...
96
+
97
+ @abstractmethod
98
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
99
+ """See `EncodecModel.encode`."""
100
+ ...
101
+
102
+ @abstractmethod
103
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
104
+ """See `EncodecModel.decode`."""
105
+ ...
106
+
107
+ @abstractmethod
108
+ def decode_latent(self, codes: torch.Tensor):
109
+ """Decode from the discrete codes to continuous latent space."""
110
+ ...
111
+
112
+ @property
113
+ @abstractmethod
114
+ def channels(self) -> int:
115
+ ...
116
+
117
+ @property
118
+ @abstractmethod
119
+ def frame_rate(self) -> float:
120
+ ...
121
+
122
+ @property
123
+ @abstractmethod
124
+ def sample_rate(self) -> int:
125
+ ...
126
+
127
+ @property
128
+ @abstractmethod
129
+ def cardinality(self) -> int:
130
+ ...
131
+
132
+ @property
133
+ @abstractmethod
134
+ def num_codebooks(self) -> int:
135
+ ...
136
+
137
+ @property
138
+ @abstractmethod
139
+ def total_codebooks(self) -> int:
140
+ ...
141
+
142
+ @abstractmethod
143
+ def set_num_codebooks(self, n: int):
144
+ """Set the active number of codebooks used by the quantizer."""
145
+ ...
146
+
147
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
148
+ assert norm in CONV_NORMALIZATIONS
149
+ if norm == 'weight_norm':
150
+ return weight_norm(module)
151
+ elif norm == 'spectral_norm':
152
+ return spectral_norm(module)
153
+ else:
154
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
155
+ # doesn't need reparametrization.
156
+ return module
157
+
158
+
159
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
160
+ """Return the proper normalization module. If causal is True, this will ensure the returned
161
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
162
+ """
163
+ assert norm in CONV_NORMALIZATIONS
164
+ if norm == 'time_group_norm':
165
+ if causal:
166
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
167
+ assert isinstance(module, nn.modules.conv._ConvNd)
168
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
169
+ else:
170
+ return nn.Identity()
171
+
172
+
173
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
174
+ padding_total: int = 0) -> int:
175
+ """See `pad_for_conv1d`."""
176
+ length = x.shape[-1]
177
+ n_frames = (length - kernel_size + padding_total) / stride + 1
178
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
179
+ return ideal_length - length
180
+
181
+
182
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
183
+ """Pad for a convolution to make sure that the last window is full.
184
+ Extra padding is added at the end. This is required to ensure that we can rebuild
185
+ an output of the same length, as otherwise, even with padding, some time steps
186
+ might get removed.
187
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
188
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
189
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
190
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
191
+ 1 2 3 4 # once you removed padding, we are missing one time step !
192
+ """
193
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
194
+ return F.pad(x, (0, extra_padding))
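# Quick sanity check of the padding formula above (a standalone sketch with concrete numbers):
# input length 5, kernel 4, stride 2, so padding_total = kernel_size - stride = 2.
import math
length, kernel_size, stride = 5, 4, 2
padding_total = kernel_size - stride
n_frames = (length - kernel_size + padding_total) / stride + 1                      # 2.5
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)   # 6
extra_padding = ideal_length - length                                               # 1 extra sample so the last window is full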
195
+
196
+
197
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
198
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
199
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
200
+ """
201
+ length = x.shape[-1]
202
+ padding_left, padding_right = paddings
203
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
204
+ if mode == 'reflect':
205
+ max_pad = max(padding_left, padding_right)
206
+ extra_pad = 0
207
+ if length <= max_pad:
208
+ extra_pad = max_pad - length + 1
209
+ x = F.pad(x, (0, extra_pad))
210
+ padded = F.pad(x, paddings, mode, value)
211
+ end = padded.shape[-1] - extra_pad
212
+ return padded[..., :end]
213
+ else:
214
+ return F.pad(x, paddings, mode, value)
215
+
216
+
217
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
218
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
219
+ padding_left, padding_right = paddings
220
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
221
+ assert (padding_left + padding_right) <= x.shape[-1]
222
+ end = x.shape[-1] - padding_right
223
+ return x[..., padding_left: end]
224
+
225
+
226
+ class NormConv1d(nn.Module):
227
+ """Wrapper around Conv1d and normalization applied to this conv
228
+ to provide a uniform interface across normalization approaches.
229
+ """
230
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
231
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
232
+ super().__init__()
233
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
234
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
235
+ self.norm_type = norm
236
+
237
+ def forward(self, x):
238
+ x = self.conv(x)
239
+ x = self.norm(x)
240
+ return x
241
+
242
+
243
+ class NormConv2d(nn.Module):
244
+ """Wrapper around Conv2d and normalization applied to this conv
245
+ to provide a uniform interface across normalization approaches.
246
+ """
247
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
248
+ super().__init__()
249
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
250
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
251
+ self.norm_type = norm
252
+
253
+ def forward(self, x):
254
+ x = self.conv(x)
255
+ x = self.norm(x)
256
+ return x
257
+
258
+
259
+ class NormConvTranspose1d(nn.Module):
260
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
261
+ to provide a uniform interface across normalization approaches.
262
+ """
263
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
264
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
265
+ super().__init__()
266
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
267
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
268
+ self.norm_type = norm
269
+
270
+ def forward(self, x):
271
+ x = self.convtr(x)
272
+ x = self.norm(x)
273
+ return x
274
+
275
+
276
+ class NormConvTranspose2d(nn.Module):
277
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
278
+ to provide a uniform interface across normalization approaches.
279
+ """
280
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
281
+ super().__init__()
282
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
283
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
284
+
285
+ def forward(self, x):
286
+ x = self.convtr(x)
287
+ x = self.norm(x)
288
+ return x
289
+
290
+
291
+ class StreamableConv1d(nn.Module):
292
+ """Conv1d with some builtin handling of asymmetric or causal padding
293
+ and normalization.
294
+ """
295
+ def __init__(self, in_channels: int, out_channels: int,
296
+ kernel_size: int, stride: int = 1, dilation: int = 1,
297
+ groups: int = 1, bias: bool = True, causal: bool = False,
298
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
299
+ pad_mode: str = 'reflect'):
300
+ super().__init__()
301
+ # warn user on unusual setup between dilation and stride
302
+ if stride > 1 and dilation > 1:
303
+ warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
304
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
305
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
306
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
307
+ norm=norm, norm_kwargs=norm_kwargs)
308
+ self.causal = causal
309
+ self.pad_mode = pad_mode
310
+
311
+ def forward(self, x):
312
+ B, C, T = x.shape
313
+ kernel_size = self.conv.conv.kernel_size[0]
314
+ stride = self.conv.conv.stride[0]
315
+ dilation = self.conv.conv.dilation[0]
316
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
317
+ padding_total = kernel_size - stride
318
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
319
+ if self.causal:
320
+ # Left padding for causal
321
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
322
+ else:
323
+ # Asymmetric padding required for odd strides
324
+ padding_right = padding_total // 2
325
+ padding_left = padding_total - padding_right
326
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
327
+ return self.conv(x)
328
+
329
+
330
+ class StreamableConvTranspose1d(nn.Module):
331
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
332
+ and normalization.
333
+ """
334
+ def __init__(self, in_channels: int, out_channels: int,
335
+ kernel_size: int, stride: int = 1, causal: bool = False,
336
+ norm: str = 'none', trim_right_ratio: float = 1.,
337
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
338
+ super().__init__()
339
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
340
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
341
+ self.causal = causal
342
+ self.trim_right_ratio = trim_right_ratio
343
+ assert self.causal or self.trim_right_ratio == 1., \
344
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
345
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
346
+
347
+ def forward(self, x):
348
+ kernel_size = self.convtr.convtr.kernel_size[0]
349
+ stride = self.convtr.convtr.stride[0]
350
+ padding_total = kernel_size - stride
351
+
352
+ y = self.convtr(x)
353
+
354
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
355
+ # removed at the very end, when keeping only the right length for the output,
356
+ # as removing it here would require also passing the length at the matching layer
357
+ # in the encoder.
358
+ if self.causal:
359
+ # Trim the padding on the right according to the specified ratio
360
+ # if trim_right_ratio = 1.0, trim everything from right
361
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
362
+ padding_left = padding_total - padding_right
363
+ y = unpad1d(y, (padding_left, padding_right))
364
+ else:
365
+ # Asymmetric padding required for odd strides
366
+ padding_right = padding_total // 2
367
+ padding_left = padding_total - padding_right
368
+ y = unpad1d(y, (padding_left, padding_right))
369
+ return y
370
+
371
+
372
+ class StreamableLSTM(nn.Module):
373
+ """LSTM without worrying about the hidden state, nor the layout of the data.
374
+ Expects input as convolutional layout.
375
+ """
376
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
377
+ super().__init__()
378
+ self.skip = skip
379
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
380
+
381
+ def forward(self, x):
382
+ x = x.permute(2, 0, 1)
383
+ y, _ = self.lstm(x)
384
+ if self.skip:
385
+ y = y + x
386
+ y = y.permute(1, 2, 0)
387
+ return y
388
+
389
+
390
+ class SEANetResnetBlock(nn.Module):
391
+ """Residual block from SEANet model.
392
+
393
+ Args:
394
+ dim (int): Dimension of the input/output.
395
+ kernel_sizes (list): List of kernel sizes for the convolutions.
396
+ dilations (list): List of dilations for the convolutions.
397
+ activation (str): Activation function.
398
+ activation_params (dict): Parameters to provide to the activation function.
399
+ norm (str): Normalization method.
400
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
401
+ causal (bool): Whether to use fully causal convolution.
402
+ pad_mode (str): Padding mode for the convolutions.
403
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
404
+ true_skip (bool): Whether to use true skip connection or a simple
405
+ (streamable) convolution as the skip connection.
406
+ """
407
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
408
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
409
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
410
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
411
+ super().__init__()
412
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
413
+ act = getattr(nn, activation)
414
+ hidden = dim // compress
415
+ block = []
416
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
417
+ in_chs = dim if i == 0 else hidden
418
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
419
+ block += [
420
+ act(**activation_params),
421
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
422
+ norm=norm, norm_kwargs=norm_params,
423
+ causal=causal, pad_mode=pad_mode),
424
+ ]
425
+ self.block = nn.Sequential(*block)
426
+ self.shortcut: nn.Module
427
+ if true_skip:
428
+ self.shortcut = nn.Identity()
429
+ else:
430
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
431
+ causal=causal, pad_mode=pad_mode)
432
+
433
+ def forward(self, x):
434
+ return self.shortcut(x) + self.block(x)
435
+
436
+
437
+ class SEANetEncoder(nn.Module):
438
+ """SEANet encoder.
439
+
440
+ Args:
441
+ channels (int): Audio channels.
442
+ dimension (int): Intermediate representation dimension.
443
+ n_filters (int): Base width for the model.
444
+ n_residual_layers (int): nb of residual layers.
445
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
446
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
447
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
448
+ activation (str): Activation function.
449
+ activation_params (dict): Parameters to provide to the activation function.
450
+ norm (str): Normalization method.
451
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
452
+ kernel_size (int): Kernel size for the initial convolution.
453
+ last_kernel_size (int): Kernel size for the last convolution.
454
+ residual_kernel_size (int): Kernel size for the residual layers.
455
+ dilation_base (int): How much to increase the dilation with each layer.
456
+ causal (bool): Whether to use fully causal convolution.
457
+ pad_mode (str): Padding mode for the convolutions.
458
+ true_skip (bool): Whether to use true skip connection or a simple
459
+ (streamable) convolution as the skip connection in the residual network blocks.
460
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
461
+ lstm (int): Number of LSTM layers at the end of the encoder.
462
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
463
+ For the encoder, it corresponds to the N first blocks.
464
+ """
465
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
466
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
467
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
468
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
469
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
470
+ disable_norm_outer_blocks: int = 0):
471
+ super().__init__()
472
+ self.channels = channels
473
+ self.dimension = dimension
474
+ self.n_filters = n_filters
475
+ self.ratios = list(reversed(ratios))
476
+ del ratios
477
+ self.n_residual_layers = n_residual_layers
478
+ self.hop_length = np.prod(self.ratios)
479
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
480
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
481
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
482
+ "Number of blocks for which to disable norm is invalid." \
483
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
484
+
485
+ act = getattr(nn, activation)
486
+ mult = 1
487
+ model: tp.List[nn.Module] = [
488
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
489
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
490
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
491
+ ]
492
+ # Downsample to raw audio scale
493
+ for i, ratio in enumerate(self.ratios):
494
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
495
+ # Add residual layers
496
+ for j in range(n_residual_layers):
497
+ model += [
498
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
499
+ dilations=[dilation_base ** j, 1],
500
+ norm=block_norm, norm_params=norm_params,
501
+ activation=activation, activation_params=activation_params,
502
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
503
+
504
+ # Add downsampling layers
505
+ model += [
506
+ act(**activation_params),
507
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
508
+ kernel_size=ratio * 2, stride=ratio,
509
+ norm=block_norm, norm_kwargs=norm_params,
510
+ causal=causal, pad_mode=pad_mode),
511
+ ]
512
+ mult *= 2
513
+
514
+ if lstm:
515
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
516
+
517
+ model += [
518
+ act(**activation_params),
519
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
520
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
521
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
522
+ ]
523
+
524
+ self.model = nn.Sequential(*model)
525
+
526
+ def forward(self, x):
527
+ return self.model(x)
528
+
529
+
530
+ class SEANetDecoder(nn.Module):
531
+ """SEANet decoder.
532
+
533
+ Args:
534
+ channels (int): Audio channels.
535
+ dimension (int): Intermediate representation dimension.
536
+ n_filters (int): Base width for the model.
537
+ n_residual_layers (int): nb of residual layers.
538
+ ratios (Sequence[int]): kernel size and stride ratios.
539
+ activation (str): Activation function.
540
+ activation_params (dict): Parameters to provide to the activation function.
541
+ final_activation (str): Final activation function after all convolutions.
542
+ final_activation_params (dict): Parameters to provide to the activation function.
543
+ norm (str): Normalization method.
544
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
545
+ kernel_size (int): Kernel size for the initial convolution.
546
+ last_kernel_size (int): Kernel size for the last convolution.
547
+ residual_kernel_size (int): Kernel size for the residual layers.
548
+ dilation_base (int): How much to increase the dilation with each layer.
549
+ causal (bool): Whether to use fully causal convolution.
550
+ pad_mode (str): Padding mode for the convolutions.
551
+ true_skip (bool): Whether to use true skip connection or a simple
552
+ (streamable) convolution as the skip connection in the residual network blocks.
553
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
554
+ lstm (int): Number of LSTM layers at the start of the decoder.
555
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
556
+ For the decoder, it corresponds to the N last blocks.
557
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
558
+ If equal to 1.0, it means that all the trimming is done at the right.
559
+ """
560
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
561
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
562
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
563
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
564
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
565
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
566
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
567
+ super().__init__()
568
+ self.dimension = dimension
569
+ self.channels = channels
570
+ self.n_filters = n_filters
571
+ self.ratios = ratios
572
+ del ratios
573
+ self.n_residual_layers = n_residual_layers
574
+ self.hop_length = np.prod(self.ratios)
575
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
576
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
577
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
578
+ "Number of blocks for which to disable norm is invalid." \
579
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
580
+
581
+ act = getattr(nn, activation)
582
+ mult = int(2 ** len(self.ratios))
583
+ model: tp.List[nn.Module] = [
584
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
585
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
586
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
587
+ ]
588
+
589
+ if lstm:
590
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
591
+
592
+ # Upsample to raw audio scale
593
+ for i, ratio in enumerate(self.ratios):
594
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
595
+ # Add upsampling layers
596
+ model += [
597
+ act(**activation_params),
598
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
599
+ kernel_size=ratio * 2, stride=ratio,
600
+ norm=block_norm, norm_kwargs=norm_params,
601
+ causal=causal, trim_right_ratio=trim_right_ratio),
602
+ ]
603
+ # Add residual layers
604
+ for j in range(n_residual_layers):
605
+ model += [
606
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
607
+ dilations=[dilation_base ** j, 1],
608
+ activation=activation, activation_params=activation_params,
609
+ norm=block_norm, norm_params=norm_params, causal=causal,
610
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
611
+
612
+ mult //= 2
613
+
614
+ # Add final layers
615
+ model += [
616
+ act(**activation_params),
617
+ StreamableConv1d(n_filters, channels, last_kernel_size,
618
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
619
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
620
+ ]
621
+ # Add optional final activation to decoder (eg. tanh)
622
+ if final_activation is not None:
623
+ final_act = getattr(nn, final_activation)
624
+ final_activation_params = final_activation_params or {}
625
+ model += [
626
+ final_act(**final_activation_params)
627
+ ]
628
+ self.model = nn.Sequential(*model)
629
+
630
+ def forward(self, z):
631
+ y = self.model(z)
632
+ return y
633
+
634
+
635
+ def exists(val: tp.Optional[tp.Any]) -> bool:
636
+ return val is not None
637
+
638
+
639
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
640
+ return val if exists(val) else d
641
+
642
+
643
+ def l2norm(t):
644
+ return F.normalize(t, p=2, dim=-1)
645
+
646
+
647
+ def ema_inplace(moving_avg, new, decay: float):
648
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
649
+
650
+
651
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
652
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
653
+
654
+
655
+ def uniform_init(*shape: int):
656
+ t = torch.empty(shape)
657
+ nn.init.kaiming_uniform_(t)
658
+ return t
659
+
660
+
661
+ def sample_vectors(samples, num: int):
662
+ num_samples, device = samples.shape[0], samples.device
663
+
664
+ if num_samples >= num:
665
+ indices = torch.randperm(num_samples, device=device)[:num]
666
+ else:
667
+ indices = torch.randint(0, num_samples, (num,), device=device)
668
+
669
+ return samples[indices]
670
+
671
+
672
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
673
+ dim, dtype = samples.shape[-1], samples.dtype
674
+
675
+ means = sample_vectors(samples, num_clusters)
676
+
677
+ for _ in range(num_iters):
678
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
679
+ means, "c d -> () c d"
680
+ )
681
+ dists = -(diffs ** 2).sum(dim=-1)
682
+
683
+ buckets = dists.max(dim=-1).indices
684
+ bins = torch.bincount(buckets, minlength=num_clusters)
685
+ zero_mask = bins == 0
686
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
687
+
688
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
689
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
690
+ new_means = new_means / bins_min_clamped[..., None]
691
+
692
+ means = torch.where(zero_mask[..., None], means, new_means)
693
+
694
+ return means, bins
695
+
696
+
697
+ def orthogonal_loss_fn(t):
698
+ # eq (2) from https://arxiv.org/abs/2112.00384
699
+ n = t.shape[0]
700
+ normed_codes = l2norm(t)
701
+ identity = torch.eye(n, device=t.device)
702
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
703
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
704
+
705
+
706
+ class EuclideanCodebook(nn.Module):
707
+ """Codebook with Euclidean distance.
708
+
709
+ Args:
710
+ dim (int): Dimension.
711
+ codebook_size (int): Codebook size.
712
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
713
+ If set to true, run the k-means algorithm on the first training batch and use
714
+ the learned centroids as initialization.
715
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
716
+ decay (float): Decay for exponential moving average over the codebooks.
717
+ epsilon (float): Epsilon value for numerical stability.
718
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
719
+ that have an exponential moving average cluster size less than the specified threshold with
720
+ a randomly selected vector from the current batch.
721
+ """
722
+ def __init__(
723
+ self,
724
+ dim: int,
725
+ codebook_size: int,
726
+ kmeans_init: int = False,
727
+ kmeans_iters: int = 10,
728
+ decay: float = 0.8,
729
+ epsilon: float = 1e-5,
730
+ threshold_ema_dead_code: int = 2,
731
+ ):
732
+ super().__init__()
733
+ self.decay = decay
734
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
735
+ embed = init_fn(codebook_size, dim)
736
+
737
+ self.codebook_size = codebook_size
738
+
739
+ self.kmeans_iters = kmeans_iters
740
+ self.epsilon = epsilon
741
+ self.threshold_ema_dead_code = threshold_ema_dead_code
742
+
743
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
744
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
745
+ self.register_buffer("embed", embed)
746
+ self.register_buffer("embed_avg", embed.clone())
747
+
748
+ @torch.jit.ignore
749
+ def init_embed_(self, data):
750
+ if self.inited:
751
+ return
752
+
753
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
754
+ self.embed.data.copy_(embed)
755
+ self.embed_avg.data.copy_(embed.clone())
756
+ self.cluster_size.data.copy_(cluster_size)
757
+ self.inited.data.copy_(torch.Tensor([True]))
758
+ # Make sure all buffers across workers are in sync after initialization
759
+ flashy.distrib.broadcast_tensors(self.buffers())
760
+
761
+ def replace_(self, samples, mask):
762
+ modified_codebook = torch.where(
763
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
764
+ )
765
+ self.embed.data.copy_(modified_codebook)
766
+
767
+ def expire_codes_(self, batch_samples):
768
+ if self.threshold_ema_dead_code == 0:
769
+ return
770
+
771
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
772
+ if not torch.any(expired_codes):
773
+ return
774
+
775
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
776
+ self.replace_(batch_samples, mask=expired_codes)
777
+ flashy.distrib.broadcast_tensors(self.buffers())
778
+
779
+ def preprocess(self, x):
780
+ x = rearrange(x, "... d -> (...) d")
781
+ return x
782
+
783
+ def quantize(self, x):
784
+ embed = self.embed.t()
785
+ dist = -(
786
+ x.pow(2).sum(1, keepdim=True)
787
+ - 2 * x @ embed
788
+ + embed.pow(2).sum(0, keepdim=True)
789
+ )
790
+ embed_ind = dist.max(dim=-1).indices
791
+ return embed_ind
792
+
793
+ def postprocess_emb(self, embed_ind, shape):
794
+ return embed_ind.view(*shape[:-1])
795
+
796
+ def dequantize(self, embed_ind):
797
+ quantize = F.embedding(embed_ind, self.embed)
798
+ return quantize
799
+
800
+ def encode(self, x):
801
+ shape = x.shape
802
+ # pre-process
803
+ x = self.preprocess(x)
804
+ # quantize
805
+ embed_ind = self.quantize(x)
806
+ # post-process
807
+ embed_ind = self.postprocess_emb(embed_ind, shape)
808
+ return embed_ind
809
+
810
+ def decode(self, embed_ind):
811
+ quantize = self.dequantize(embed_ind)
812
+ return quantize
813
+
814
+ def forward(self, x):
815
+ raise NotImplementedError()
816
+ shape, dtype = x.shape, x.dtype
817
+ x = self.preprocess(x)
818
+ self.init_embed_(x)
819
+
820
+ embed_ind = self.quantize(x)
821
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
822
+ embed_ind = self.postprocess_emb(embed_ind, shape)
823
+ quantize = self.dequantize(embed_ind)
824
+
825
+ if self.training:
826
+ # We do the expiry of code at that point as buffers are in sync
827
+ # and all the workers will take the same decision.
828
+ self.expire_codes_(x)
829
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
830
+ embed_sum = x.t() @ embed_onehot
831
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
832
+ cluster_size = (
833
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
834
+ * self.cluster_size.sum()
835
+ )
836
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
837
+ self.embed.data.copy_(embed_normalized)
838
+
839
+ return quantize, embed_ind
840
+
841
+
842
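For reference, a minimal standalone sketch (not part of the uploaded files) of the distance lookup and EMA statistics that EuclideanCodebook implements above; tensor sizes are made up for illustration.

import torch
import torch.nn.functional as F

codebook_size, dim, n = 8, 4, 16                     # toy sizes (assumptions)
embed = torch.randn(codebook_size, dim)              # codebook entries
x = torch.randn(n, dim)                              # flattened inputs, shape (N, d)

# Nearest code via the expanded squared Euclidean distance, as in EuclideanCodebook.quantize
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed.t() + embed.t().pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices                 # (N,) code indices
quantized = F.embedding(embed_ind, embed)            # dequantize back to vectors

# Training-time EMA statistics (cf. ema_inplace and the cluster_size/embed_avg buffers)
onehot = F.one_hot(embed_ind, codebook_size).float()
cluster_size_new = onehot.sum(0)                     # hits per code in this batch
embed_sum = x.t() @ onehot                           # per-code sum of assigned vectors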
+ class VectorQuantization(nn.Module):
843
+ """Vector quantization implementation.
844
+ Currently supports only euclidean distance.
845
+
846
+ Args:
847
+ dim (int): Dimension
848
+ codebook_size (int): Codebook size
849
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
850
+ decay (float): Decay for exponential moving average over the codebooks.
851
+ epsilon (float): Epsilon value for numerical stability.
852
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
853
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
854
+ threshold_ema_dead_code (int):
855
+ channels_last (bool): Channels are the last dimension in the input tensors.
856
+ commitment_weight (float): Weight for commitment loss.
857
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
858
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
859
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
860
+ for orthogonal regularization.
861
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
862
+ that have an exponential moving average cluster size less than the specified threshold with
863
+ a randomly selected vector from the current batch.
864
+ """
865
+ def __init__(
866
+ self,
867
+ dim: int,
868
+ codebook_size: int,
869
+ codebook_dim: tp.Optional[int] = None,
870
+ decay: float = 0.8,
871
+ epsilon: float = 1e-5,
872
+ kmeans_init: bool = False,
873
+ kmeans_iters: int = 10,
874
+ threshold_ema_dead_code: int = 2,
875
+ channels_last: bool = False,
876
+ commitment_weight: float = 1.,
877
+ orthogonal_reg_weight: float = 0.0,
878
+ orthogonal_reg_active_codes_only: bool = False,
879
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
880
+ ):
881
+ super().__init__()
882
+ _codebook_dim: int = default(codebook_dim, dim)
883
+
884
+ requires_projection = _codebook_dim != dim
885
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
886
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
887
+
888
+ self.epsilon = epsilon
889
+ self.commitment_weight = commitment_weight
890
+
891
+ self.orthogonal_reg_weight = orthogonal_reg_weight
892
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
893
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
894
+
895
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
896
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
897
+ decay=decay, epsilon=epsilon,
898
+ threshold_ema_dead_code=threshold_ema_dead_code)
899
+ self.codebook_size = codebook_size
900
+
901
+ self.channels_last = channels_last
902
+
903
+ @property
904
+ def codebook(self):
905
+ return self._codebook.embed
906
+
907
+ @property
908
+ def inited(self):
909
+ return self._codebook.inited
910
+
911
+ def _preprocess(self, x):
912
+ if not self.channels_last:
913
+ x = rearrange(x, "b d n -> b n d")
914
+ return x
915
+
916
+ def _postprocess(self, quantize):
917
+ if not self.channels_last:
918
+ quantize = rearrange(quantize, "b n d -> b d n")
919
+ return quantize
920
+
921
+ def encode(self, x):
922
+ x = self._preprocess(x)
923
+ x = self.project_in(x)
924
+ embed_in = self._codebook.encode(x)
925
+ return embed_in
926
+
927
+ def decode(self, embed_ind):
928
+ quantize = self._codebook.decode(embed_ind)
929
+ quantize = self.project_out(quantize)
930
+ quantize = self._postprocess(quantize)
931
+ return quantize
932
+
933
+ def forward(self, x):
934
+ device = x.device
935
+ x = self._preprocess(x)
936
+
937
+ x = self.project_in(x)
938
+ quantize, embed_ind = self._codebook(x)
939
+
940
+ if self.training:
941
+ quantize = x + (quantize - x).detach()
942
+
943
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
944
+
945
+ if self.training:
946
+ if self.commitment_weight > 0:
947
+ commit_loss = F.mse_loss(quantize.detach(), x)
948
+ loss = loss + commit_loss * self.commitment_weight
949
+
950
+ if self.orthogonal_reg_weight > 0:
951
+ codebook = self.codebook
952
+
953
+ if self.orthogonal_reg_active_codes_only:
954
+ # only calculate orthogonal loss for the activated codes for this batch
955
+ unique_code_ids = torch.unique(embed_ind)
956
+ codebook = codebook[unique_code_ids]
957
+
958
+ num_codes = codebook.shape[0]
959
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
960
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
961
+ codebook = codebook[rand_ids]
962
+
963
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
964
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
965
+
966
+ quantize = self.project_out(quantize)
967
+ quantize = self._postprocess(quantize)
968
+
969
+ return quantize, embed_ind, loss
970
+
971
+
972
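A small sketch (assumptions: toy shapes, no real codebook) of the straight-through estimator and commitment loss used in VectorQuantization.forward above; it is only meant to show where the gradients go.

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 4, requires_grad=True)      # encoder output after project_in, (B, T, d)
quantize = torch.randn(2, 5, 4)                   # stand-in for the codebook output (no gradient)

# Straight-through estimator: the forward value is `quantize`,
# but the gradient passes to `x` as if quantization were the identity.
st = x + (quantize - x).detach()

# Commitment loss pulls the encoder output towards the (detached) quantized values.
commit_loss = F.mse_loss(st.detach(), x)
commit_loss.backward()
print(st.allclose(quantize), x.grad is not None)  # True True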
+ class ResidualVectorQuantization(nn.Module):
973
+ """Residual vector quantization implementation.
974
+
975
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
976
+ """
977
+ def __init__(self, *, num_quantizers, **kwargs):
978
+ super().__init__()
979
+ codebook_size = kwargs.pop('codebook_size', None)
980
+ if codebook_size is None:
981
+ raise ValueError("codebook_size must be provided in kwargs")
982
+ if not isinstance(codebook_size, list):
983
+ codebook_size = [codebook_size] * num_quantizers
984
+ self.layers = nn.ModuleList(
985
+ [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)]
986
+ )
987
+
988
+
989
+ # self.layers = nn.ModuleList(
990
+ # [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
991
+ # )
992
+
993
+ def forward(self, x, n_q: tp.Optional[int] = None):
994
+ quantized_out = 0.0
995
+ residual = x
996
+
997
+ all_losses = []
998
+ all_indices = []
999
+
1000
+ n_q = n_q or len(self.layers)
1001
+
1002
+ for i, layer in enumerate(self.layers[:n_q]):
1003
+ quantized, indices, loss = layer(residual)
1004
+ residual = residual - quantized
1005
+ quantized_out = quantized_out + quantized
1006
+ all_indices.append(indices)
1007
+ all_losses.append(loss)
1008
+
1009
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
1010
+ return quantized_out, out_indices, out_losses
1011
+
1012
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
1013
+ residual = x
1014
+ all_indices = []
1015
+ n_q = n_q or len(self.layers)
1016
+ for layer in self.layers[:n_q]:
1017
+ indices = layer.encode(residual)
1018
+ quantized = layer.decode(indices)
1019
+ # the original code is below
1020
+ # since quantize has the gradient of residual, according to line 321
1021
+ # quantize = x + (quantize - x).detach()
1022
+ # the code below would make the commitment loss 0 for all codebooks except codebook 1
1023
+ # https://github.com/facebookresearch/encodec/issues/25
1024
+ # therefore we change it
1025
+
1026
+ residual = residual - quantized
1027
+ # residual = residual - quantized.detach()
1028
+ # since the commitment loss is averaged, the scale of the loss won't change (not as claimed in the issue above)
1029
+ all_indices.append(indices)
1030
+ out_indices = torch.stack(all_indices)
1031
+ return out_indices
1032
+
1033
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
1034
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
1035
+ for i, indices in enumerate(q_indices):
1036
+ layer = self.layers[i]
1037
+ quantized = layer.decode(indices)
1038
+ quantized_out = quantized_out + quantized
1039
+ return quantized_out
1040
+
1041
+
1042
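To make the residual loop above concrete, a toy numeric sketch where each "codebook" is replaced by a simple rounding grid (purely an illustrative stand-in for a learned VectorQuantization layer): every stage quantizes only what the previous stages missed, and the reconstruction is the sum over stages.

import torch

x = torch.tensor([0.537, -1.234])
residual, quantized_out = x.clone(), torch.zeros_like(x)
for grid in [1.0, 0.1, 0.01]:                 # finer "codebook" at each stage (assumption)
    q = (residual / grid).round() * grid      # one quantization stage
    residual = residual - q                   # later stages only see the remaining error
    quantized_out = quantized_out + q         # reconstruction accumulates the stages
print(quantized_out, (x - quantized_out).abs().max())   # error shrinks as stages are added

Stacking the per-stage indices is what gives the [K, B, T] codes that ResidualVectorQuantizer below transposes to [B, K, T].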
+ class ResidualVectorQuantizer(BaseQuantizer):
1043
+ """Residual Vector Quantizer.
1044
+
1045
+ Args:
1046
+ dimension (int): Dimension of the codebooks.
1047
+ n_q (int): Number of residual vector quantizers used.
1048
+ q_dropout (bool): Random quantizer drop out at train time.
1049
+ bins (int): Codebook size.
1050
+ decay (float): Decay for exponential moving average over the codebooks.
1051
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
1052
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
1053
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
1054
+ that have an exponential moving average cluster size less than the specified threshold with
1055
+ a randomly selected vector from the current batch.
1056
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
1057
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
1058
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
1059
+ for orthogonal regularization.
1060
+ """
1061
+ def __init__(
1062
+ self,
1063
+ dimension: int = 256,
1064
+ n_q: int = 8,
1065
+ q_dropout: bool = False,
1066
+ bins: tp.Union[int, tp.List[int]] = 1024,
1067
+ decay: float = 0.99,
1068
+ kmeans_init: bool = True,
1069
+ kmeans_iters: int = 10,
1070
+ threshold_ema_dead_code: int = 2,
1071
+ orthogonal_reg_weight: float = 0.0,
1072
+ orthogonal_reg_active_codes_only: bool = False,
1073
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
1074
+ ):
1075
+ super().__init__()
1076
+ self.max_n_q = n_q
1077
+ self.n_q = n_q
1078
+ self.q_dropout = q_dropout
1079
+ self.dimension = dimension
1080
+ self.bins = bins
1081
+ self.decay = decay
1082
+ self.kmeans_init = kmeans_init
1083
+ self.kmeans_iters = kmeans_iters
1084
+ self.threshold_ema_dead_code = threshold_ema_dead_code
1085
+ self.orthogonal_reg_weight = orthogonal_reg_weight
1086
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
1087
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
1088
+ self.vq = ResidualVectorQuantization(
1089
+ dim=self.dimension,
1090
+ codebook_size=self.bins,
1091
+ num_quantizers=self.n_q,
1092
+ decay=self.decay,
1093
+ kmeans_init=self.kmeans_init,
1094
+ kmeans_iters=self.kmeans_iters,
1095
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
1096
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
1097
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
1098
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
1099
+ channels_last=False
1100
+ )
1101
+
1102
+ def forward(self, x: torch.Tensor, frame_rate: int):
1103
+ n_q = self.n_q
1104
+ if self.training and self.q_dropout:
1105
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
1106
+ if isinstance(self.bins, list):
1107
+ bins = self.bins
1108
+ else:
1109
+ bins = [self.bins] * self.n_q
1110
+ bw_per_q = [math.log2(bin) * frame_rate / 1000 for bin in bins]
1111
+ bw = torch.tensor(sum(bw_per_q)).to(x)
1112
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
1113
+ codes = codes.transpose(0, 1)
1114
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1115
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
1116
+
1117
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1118
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
1119
+ The RVQ encode method sets the appropriate number of quantizer to use
1120
+ and returns indices for each quantizer.
1121
+ """
1122
+ n_q = self.n_q
1123
+ codes = self.vq.encode(x, n_q=n_q)
1124
+ codes = codes.transpose(0, 1)
1125
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1126
+ return codes
1127
+
1128
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1129
+ """Decode the given codes to the quantized representation."""
1130
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
1131
+ codes = codes.transpose(0, 1)
1132
+ quantized = self.vq.decode(codes)
1133
+ return quantized
1134
+
1135
+ @property
1136
+ def total_codebooks(self):
1137
+ return self.max_n_q
1138
+
1139
+ @property
1140
+ def num_codebooks(self):
1141
+ return self.n_q
1142
+
1143
+ def set_num_codebooks(self, n: int):
1144
+ assert n > 0 and n <= self.max_n_q
1145
+ self.n_q = n
1146
+
1147
+ class DummyQuantizer(BaseQuantizer):
1148
+ """Fake quantizer that actually does not perform any quantization.
1149
+ """
1150
+ def __init__(self):
1151
+ super().__init__()
1152
+
1153
+ def forward(self, x: torch.Tensor, frame_rate: int):
1154
+ q = x.unsqueeze(1)
1155
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
1156
+
1157
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1158
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
1159
+ In the case of the DummyQuantizer, the codes are actually identical
1160
+ to the input and resulting quantized representation as no quantization is done.
1161
+ """
1162
+ return x.unsqueeze(1)
1163
+
1164
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1165
+ """Decode the given codes to the quantized representation.
1166
+ In the case of the DummyQuantizer, the codes are actually identical
1167
+ to the input and resulting quantized representation as no quantization is done.
1168
+ """
1169
+ return codes.squeeze(1)
1170
+
1171
+ @property
1172
+ def total_codebooks(self):
1173
+ """Total number of codebooks."""
1174
+ return 1
1175
+
1176
+ @property
1177
+ def num_codebooks(self):
1178
+ """Total number of codebooks."""
1179
+ return self.total_codebooks
1180
+
1181
+ def set_num_codebooks(self, n: int):
1182
+ """Set the number of active codebooks."""
1183
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
1184
+
1185
+
1186
+ class EncodecModel(CompressionModel):
1187
+ """Encodec model operating on the raw waveform.
1188
+
1189
+ Args:
1190
+ encoder (nn.Module): Encoder network.
1191
+ decoder (nn.Module): Decoder network.
1192
+ quantizer (BaseQuantizer): Quantizer network.
1193
+ frame_rate (int): Frame rate for the latent representation.
1194
+ sample_rate (int): Audio sample rate.
1195
+ channels (int): Number of audio channels.
1196
+ causal (bool): Whether to use a causal version of the model.
1197
+ renormalize (bool): Whether to renormalize the audio before running the model.
1198
+ """
1199
+ # we need assignment to override the property in the abstract class,
1200
+ # I couldn't find a better way...
1201
+ frame_rate: float = 0
1202
+ sample_rate: int = 0
1203
+ channels: int = 0
1204
+
1205
+ def __init__(self,
1206
+ encoder: nn.Module,
1207
+ decoder: nn.Module,
1208
+ quantizer: BaseQuantizer,
1209
+ frame_rate: int,
1210
+ sample_rate: int,
1211
+ channels: int,
1212
+ causal: bool = False,
1213
+ renormalize: bool = False):
1214
+ super().__init__()
1215
+ self.encoder = encoder
1216
+ self.decoder = decoder
1217
+ self.quantizer = quantizer
1218
+ self.frame_rate = frame_rate
1219
+ self.sample_rate = sample_rate
1220
+ self.channels = channels
1221
+ self.renormalize = renormalize
1222
+ self.causal = causal
1223
+ if self.causal:
1224
+ # we force disabling here to avoid handling linear overlap of segments
1225
+ # as supported in original EnCodec codebase.
1226
+ assert not self.renormalize, 'Causal model does not support renormalize'
1227
+
1228
+ @property
1229
+ def total_codebooks(self):
1230
+ """Total number of quantizer codebooks available."""
1231
+ return self.quantizer.total_codebooks
1232
+
1233
+ @property
1234
+ def num_codebooks(self):
1235
+ """Active number of codebooks used by the quantizer."""
1236
+ return self.quantizer.num_codebooks
1237
+
1238
+ def set_num_codebooks(self, n: int):
1239
+ """Set the active number of codebooks used by the quantizer."""
1240
+ self.quantizer.set_num_codebooks(n)
1241
+
1242
+ @property
1243
+ def cardinality(self):
1244
+ """Cardinality of each codebook."""
1245
+ return self.quantizer.bins
1246
+
1247
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1248
+ scale: tp.Optional[torch.Tensor]
1249
+ if self.renormalize:
1250
+ mono = x.mean(dim=1, keepdim=True)
1251
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1252
+ scale = 1e-8 + volume
1253
+ x = x / scale
1254
+ scale = scale.view(-1, 1)
1255
+ else:
1256
+ scale = None
1257
+ return x, scale
1258
+
1259
+ def postprocess(self,
1260
+ x: torch.Tensor,
1261
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1262
+ if scale is not None:
1263
+ assert self.renormalize
1264
+ x = x * scale.view(-1, 1, 1)
1265
+ return x
1266
+
1267
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1268
+ if encode:
1269
+ return self.encode(x)
1270
+ else:
1271
+ raise NotImplementedError("model forward and training is not supported.")
1272
+ assert x.dim() == 3
1273
+ length = x.shape[-1]
1274
+ x, scale = self.preprocess(x)
1275
+
1276
+ emb = self.encoder(x)
1277
+ q_res = self.quantizer(emb, self.frame_rate)
1278
+ out = self.decoder(q_res.x)
1279
+
1280
+ # remove extra padding added by the encoder and decoder
1281
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1282
+ out = out[..., :length]
1283
+
1284
+ q_res.x = self.postprocess(out, scale)
1285
+
1286
+ return q_res
1287
+
1288
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1289
+ """Encode the given input tensor to quantized representation along with scale parameter.
1290
+
1291
+ Args:
1292
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1293
+
1294
+ Returns:
1295
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1296
+ codes, an int tensor of shape [B, K, T], with K the number of codebooks used and T the number of frames.
1297
+ scale, a float tensor containing the scale for audio renormalization.
1298
+ """
1299
+ assert x.dim() == 3
1300
+ x, scale = self.preprocess(x)
1301
+ emb = self.encoder(x)
1302
+ codes = self.quantizer.encode(emb)
1303
+ return codes, scale
1304
+
1305
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1306
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1307
+ audio denormalization if needed.
1308
+
1309
+ Args:
1310
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1311
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1312
+
1313
+ Returns:
1314
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1315
+ """
1316
+ emb = self.decode_latent(codes)
1317
+ out = self.decoder(emb)
1318
+ out = self.postprocess(out, scale)
1319
+ # out contains extra padding added by the encoder and decoder
1320
+ return out
1321
+
1322
+ def decode_latent(self, codes: torch.Tensor):
1323
+ """Decode from the discrete codes to continuous latent space."""
1324
+ return self.quantizer.decode(codes)
1325
+
1326
+ class EncodecModel_encode_only(CompressionModel):
1327
+ """Encodec model operating on the raw waveform. Encode only, so no decoder
1328
+
1329
+ Args:
1330
+ encoder (nn.Module): Encoder network.
1331
+ quantizer (BaseQuantizer): Quantizer network.
1332
+ frame_rate (int): Frame rate for the latent representation.
1333
+ sample_rate (int): Audio sample rate.
1334
+ channels (int): Number of audio channels.
1335
+ causal (bool): Whether to use a causal version of the model.
1336
+ renormalize (bool): Whether to renormalize the audio before running the model.
1337
+ """
1338
+ # we need assignment to override the property in the abstract class,
1339
+ # I couldn't find a better way...
1340
+ frame_rate: float = 0
1341
+ sample_rate: int = 0
1342
+ channels: int = 0
1343
+
1344
+ def __init__(self,
1345
+ encoder: nn.Module,
1346
+ quantizer: BaseQuantizer,
1347
+ frame_rate: int,
1348
+ sample_rate: int,
1349
+ channels: int,
1350
+ causal: bool = False,
1351
+ renormalize: bool = False):
1352
+ super().__init__()
1353
+ self.encoder = encoder
1354
+ self.quantizer = quantizer
1355
+ self.frame_rate = frame_rate
1356
+ self.sample_rate = sample_rate
1357
+ self.channels = channels
1358
+ self.renormalize = renormalize
1359
+ self.causal = causal
1360
+ if self.causal:
1361
+ # we force disabling here to avoid handling linear overlap of segments
1362
+ # as supported in original EnCodec codebase.
1363
+ assert not self.renormalize, 'Causal model does not support renormalize'
1364
+
1365
+ @property
1366
+ def total_codebooks(self):
1367
+ """Total number of quantizer codebooks available."""
1368
+ return self.quantizer.total_codebooks
1369
+
1370
+ @property
1371
+ def num_codebooks(self):
1372
+ """Active number of codebooks used by the quantizer."""
1373
+ return self.quantizer.num_codebooks
1374
+
1375
+ def set_num_codebooks(self, n: int):
1376
+ """Set the active number of codebooks used by the quantizer."""
1377
+ self.quantizer.set_num_codebooks(n)
1378
+
1379
+ @property
1380
+ def cardinality(self):
1381
+ """Cardinality of each codebook."""
1382
+ return self.quantizer.bins
1383
+
1384
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1385
+ scale: tp.Optional[torch.Tensor]
1386
+ if self.renormalize:
1387
+ mono = x.mean(dim=1, keepdim=True)
1388
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1389
+ scale = 1e-8 + volume
1390
+ x = x / scale
1391
+ scale = scale.view(-1, 1)
1392
+ else:
1393
+ scale = None
1394
+ return x, scale
1395
+
1396
+ def postprocess(self,
1397
+ x: torch.Tensor,
1398
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1399
+ if scale is not None:
1400
+ assert self.renormalize
1401
+ x = x * scale.view(-1, 1, 1)
1402
+ return x
1403
+
1404
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1405
+ if encode:
1406
+ return self.encode(x)
1407
+ else:
1408
+ raise NotImplementedError("model forward and training is not supported.")
1409
+ assert x.dim() == 3
1410
+ length = x.shape[-1]
1411
+ x, scale = self.preprocess(x)
1412
+
1413
+ emb = self.encoder(x)
1414
+ q_res = self.quantizer(emb, self.frame_rate)
1415
+ out = self.decoder(q_res.x)
1416
+
1417
+ # remove extra padding added by the encoder and decoder
1418
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1419
+ out = out[..., :length]
1420
+
1421
+ q_res.x = self.postprocess(out, scale)
1422
+
1423
+ return q_res
1424
+
1425
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1426
+ """Encode the given input tensor to quantized representation along with scale parameter.
1427
+
1428
+ Args:
1429
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1430
+
1431
+ Returns:
1432
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1433
+ codes, an int tensor of shape [B, K, T], with K the number of codebooks used and T the number of frames.
1434
+ scale, a float tensor containing the scale for audio renormalization.
1435
+ """
1436
+ assert x.dim() == 3
1437
+ x, scale = self.preprocess(x)
1438
+ emb = self.encoder(x)
1439
+ codes = self.quantizer.encode(emb)
1440
+ return codes, scale
1441
+
1442
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1443
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1444
+ audio denormalization if needed.
1445
+
1446
+ Args:
1447
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1448
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1449
+
1450
+ Returns:
1451
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1452
+ """
1453
+ raise NotImplementedError("Decode is not supported for encode only model")
1454
+ emb = self.decode_latent(codes)
1455
+ out = self.decoder(emb)
1456
+ out = self.postprocess(out, scale)
1457
+ # out contains extra padding added by the encoder and decoder
1458
+ return out
1459
+
1460
+ def decode_latent(self, codes: torch.Tensor):
1461
+ """Decode from the discrete codes to continuous latent space."""
1462
+ raise NotImplementedError("Decode is not supported for encode only model")
1463
+ return self.quantizer.decode(codes)
1464
+
1465
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
1466
+ klass = {
1467
+ 'no_quant': DummyQuantizer,
1468
+ 'rvq': ResidualVectorQuantizer
1469
+ }[quantizer]
1470
+ kwargs = dict_from_config(getattr(cfg, quantizer))
1471
+ if quantizer != 'no_quant':
1472
+ kwargs['dimension'] = dimension
1473
+ return klass(**kwargs)
1474
+
1475
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
1476
+ if encoder_name == 'seanet':
1477
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
1478
+ encoder_override_kwargs = kwargs.pop('encoder')
1479
+ decoder_override_kwargs = kwargs.pop('decoder')
1480
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
1481
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
1482
+ encoder = SEANetEncoder(**encoder_kwargs)
1483
+ decoder = SEANetDecoder(**decoder_kwargs)
1484
+ return encoder, decoder
1485
+ else:
1486
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
1487
+
1488
+
1489
+ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
1490
+ """Instantiate a compression model."""
1491
+ if device is None:
1492
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1493
+ state = torch.load(ckpt_fn, map_location='cpu')
1494
+ cfg = state['xp.cfg']
1495
+ cfg.device = str(device)
1496
+ weights = state['best_state']['model']
1497
+ assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
1498
+ if encode_only:
1499
+ all_keys = list(weights.keys())
1500
+ for key in all_keys:
1501
+ if key.startswith('decoder'):
1502
+ del weights[key]
1503
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1504
+ encoder_name = kwargs.pop('autoencoder')
1505
+ quantizer_name = kwargs.pop('quantizer')
1506
+ encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
1507
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1508
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1509
+ renormalize = kwargs.pop('renormalize', False)
1510
+ # deprecated params
1511
+ kwargs.pop('renorm', None)
1512
+ compression_model = EncodecModel_encode_only(encoder, quantizer,
1513
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1514
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1515
+ compression_model.load_state_dict(weights)
1516
+ compression_model.eval()
1517
+ return compression_model
1518
+
1519
+ else:
1520
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1521
+ encoder_name = kwargs.pop('autoencoder')
1522
+ quantizer_name = kwargs.pop('quantizer')
1523
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
1524
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1525
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1526
+ renormalize = kwargs.pop('renormalize', False)
1527
+ # deprecated params
1528
+ kwargs.pop('renorm', None)
1529
+ compression_model = EncodecModel(encoder, decoder, quantizer,
1530
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1531
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1532
+ compression_model.load_state_dict(weights)
1533
+ compression_model.eval()
1534
+ return compression_model
1535
+
1536
+ if __name__ == "__main__":
1537
+ import torchaudio
1538
+ ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
1539
+ audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
1540
+ audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
1541
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1542
+ model = get_compression_model(ckpt_fn, device=device)
1543
+
1544
+ for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
1545
+ audio_in, sr = torchaudio.load(audio_in_fn)
1546
+ if sr != model.sample_rate:
1547
+ audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
1548
+ if audio_in.shape[0] == 2:
1549
+ audio_in = audio_in.mean(dim=0, keepdim=True)
1550
+ audio_in = audio_in.unsqueeze(0)
1551
+ audio_in = audio_in.to(torch.float32).to(device)
1552
+ codes = model.encode(audio_in)[0]
1553
+ audio_out = model.decode(codes)[0].cpu()
1554
+ torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
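For intuition on the numbers this model produces: with the 16 kHz / hop-length-320 setup used by the preprocessing scripts below (see the defaults in step6_encodec_encode.py), the frame-rate and bandwidth arithmetic from get_compression_model and ResidualVectorQuantizer.forward works out as in this sketch. The 4-codebook count matches the "encodec_4cb" naming used later; bins=2048 is an assumption for illustration only.

import math

sample_rate, hop_length = 16000, 320            # step6 defaults: model_sr and downsample_rate
frame_rate = sample_rate // hop_length          # 50 codes per second (model_code_sr)

n_q, bins = 4, 2048                             # assumed codebook count / size, for illustration
bw_per_q = [math.log2(b) * frame_rate / 1000 for b in [bins] * n_q]   # kbps per codebook
print(frame_rate, sum(bw_per_q))                # 50 frames/s, 2.2 kbps total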
data/emilia_preprocessing/sha256hash.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import sys
3
+
4
+ def sha256_hash_file(filename):
5
+ sha256_hash = hashlib.sha256()
6
+ with open(filename, "rb") as file:
7
+ # Read and update hash string in chunks to handle large files
8
+ for byte_block in iter(lambda: file.read(4096), b""):
9
+ sha256_hash.update(byte_block)
10
+ return sha256_hash.hexdigest()
11
+
12
+ # Usage example
13
+ filename = sys.argv[1]
14
+ print(sha256_hash_file(filename))
data/emilia_preprocessing/step1_download.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # conda activate emilia
2
+ from datasets import load_dataset
3
+ import fire
4
+ def main(root: str="/data/scratch/pyp/datasets/emilia"):
5
+ path = "EN/*.tar.gz*"
6
+ dataset = load_dataset("amphion/Emilia-Dataset", data_files={"en": path}, split="en", streaming=False, revision="fc71e07e8572f5f3be1dbd02ed3172a4d298f152", cache_dir=root)
7
+
8
+ if __name__ == "__main__":
9
+ fire.Fire(main)
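For a quick dry run one can point the same load_dataset call at a narrower glob so only a shard or two is fetched; the shard name below is hypothetical, and the pinned revision is the one used above.

from datasets import load_dataset

root = "/data/scratch/pyp/datasets/emilia"
path = "EN/EN_B00000.tar.gz*"                  # hypothetical single-shard pattern for a dry run
dataset = load_dataset(
    "amphion/Emilia-Dataset",
    data_files={"en": path},
    split="en",
    streaming=False,
    revision="fc71e07e8572f5f3be1dbd02ed3172a4d298f152",
    cache_dir=root,
)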
data/emilia_preprocessing/step2_log_tar_files.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ emilia_root=$1 # /data/scratch/pyp/datasets/emilia/downloads
2
+ for file in ${emilia_root}/* ; do
3
+ # Check if gzip compressed archive
4
+ if file "$file" | grep -q 'gzip compressed data'; then
5
+ # Extract string of form 'was "EN_B00100.tar"'' from the output of the file command to keep EN_B00100
6
+ filename=$(file "$file" | grep -oP '(?<=was ")[^"]*' | sed 's/\.tar$//')
7
+ # Get the file size
8
+ size=$(du -sh "$file" | cut -f1)
9
+ original_filename=$(basename "$file")
10
+ # Get URL string from corresponding JSON file with same basename
11
+ json_file=$file.json
12
+ if [ -f "$json_file" ]; then
13
+ # url=$(jq -r '.url' "$json_file") # jq is not installed on the server
14
+ url=$(python3 -c "import sys, json; print(json.load(open('$json_file'))['url'])")
15
+ else
16
+ url="N/A"
17
+ fi
18
+ # Compute SHA256 hash of the file
19
+ hash=$(python sha256hash.py "$file")
20
+ echo $original_filename
21
+ # Write filename, size, hash, original filename, URL to output file
22
+ echo "$filename, $size, $hash, $original_filename, $url" >> file_log.txt
23
+ fi
24
+ done
25
+
26
+ # Sort the output file by filename
27
+ sort -o file_log.txt file_log.txt
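The script writes one comma-separated record per archive into file_log.txt in the order echoed above (filename, size, sha256, original filename, url). A small sketch of reading those records back in Python, e.g. for spot-checking a hash against sha256hash.py:

# Sketch: parse file_log.txt as written by this script
records = []
with open("file_log.txt") as f:
    for line in f:
        filename, size, sha256, original_filename, url = [p.strip() for p in line.split(",", 4)]
        records.append({"filename": filename, "size": size, "sha256": sha256,
                        "original_filename": original_filename, "url": url})
print(len(records), records[0]["filename"] if records else None)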
data/emilia_preprocessing/step3_untar.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Define the root directory where the tar files are located
4
+ root=$1 # /data/scratch/pyp/datasets/emilia/downloads
5
+ save_root=$2 # /data/scratch/pyp/datasets/emilia/preprocessed/audio
6
+
7
+ mkdir -p "${save_root}"
8
+
9
+ # Input log files
10
+ log_file="file_log.txt" # Full log of files to process
11
+ exist_log_file="file_log_debug.txt" # Log of already processed files
12
+ failure_log="untar_failures.log" # Log file for untar failures
13
+
14
+ # Clear previous failure log
15
+ > "$failure_log"
16
+
17
+ # Create an array of filenames already processed (from exist_log_file)
18
+ if [ -f "$exist_log_file" ]; then
19
+ mapfile -t existing_files < "$exist_log_file"
20
+ else
21
+ existing_files=()
22
+ fi
23
+
24
+ # Create a temporary filtered log of files to process
25
+ filtered_log="filtered_file_log.txt"
26
+ grep -v -F -f "$exist_log_file" "$log_file" > "$filtered_log"
27
+
28
+ # Count total filtered files
29
+ total_files=$(wc -l < "$filtered_log")
30
+ echo "Found $total_files entries to process in $filtered_log."
31
+
32
+ # Print the filtered files
33
+ echo "Filtered files to process:"
34
+ cat "$filtered_log"
35
+ echo
36
+
37
+ # Confirm before starting processing
38
+ read -p "Do you want to proceed with the above files? (y/n): " confirm
39
+ if [[ "$confirm" != "y" ]]; then
40
+ echo "Operation canceled."
41
+ rm -f "$filtered_log"
42
+ exit 1
43
+ fi
44
+
45
+ # Start time
46
+ start_time=$(date +%s)
47
+
48
+ # Counter for how many lines we've processed
49
+ count=0
50
+
51
+ # Process filtered log
52
+ while IFS=',' read -r filename size local_sha256 original_filename url; do
53
+ count=$((count + 1))
54
+
55
+ # Trim leading/trailing whitespace
56
+ filename=$(echo "$filename" | xargs)
57
+ size=$(echo "$size" | xargs)
58
+ local_sha256=$(echo "$local_sha256" | xargs)
59
+ original_filename=$(echo "$original_filename" | xargs)
60
+ url=$(echo "$url" | xargs)
61
+
62
+ # Construct the full path to the tar file
63
+ tar_file="${root}/${original_filename}"
64
+
65
+ # Check if the tar file exists
66
+ if [ ! -f "$tar_file" ]; then
67
+ echo "❌ File not found: $tar_file"
68
+ echo "$filename, $size, $local_sha256, $original_filename, $url" >> "$failure_log"
69
+ else
70
+ # Try to untar the file
71
+ echo "[$count/$total_files] Untarring $tar_file..."
72
+
73
+ if ! tar -xf "$tar_file" -C "${save_root}"; then
74
+ # If untar fails, log the failure
75
+ echo "❌ Failed to untar: $tar_file"
76
+ echo "$filename, $size, $local_sha256, $original_filename, $url" >> "$failure_log"
77
+ else
78
+ echo "✅ Successfully untarred: $tar_file"
79
+ # Append successfully untarred filename to exist_log_file
80
+ echo "$filename" >> "$exist_log_file"
81
+ fi
82
+ fi
83
+
84
+ # Calculate elapsed time, average time per file, and ETA
85
+ now=$(date +%s)
86
+ elapsed=$(( now - start_time )) # total seconds since the start
87
+ if [ $count -gt 0 ]; then
88
+ avg_time=$(awk "BEGIN { printf \"%.2f\", $elapsed / $count }")
89
+ remain=$(( total_files - count ))
90
+ eta_seconds=$(awk "BEGIN { printf \"%.0f\", $avg_time * $remain }")
91
+ eta_formatted=$(date -ud "@${eta_seconds}" +'%H:%M:%S')
92
+ echo "Elapsed: ${elapsed}s | Avg/f: ${avg_time}s | Remaining: $remain files | ETA: ~${eta_formatted}"
93
+ fi
94
+
95
+ done < "$filtered_log"
96
+
97
+ # Clean up temporary filtered log
98
+ rm -f "$filtered_log"
99
+
100
+ # Summary
101
+ echo "Untar operation completed. Check $failure_log for any failures."
data/emilia_preprocessing/step4_construct_manifest.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Construct the manifest file for training; note that we only have one train split.
2
+ # Also create a neighbors folder: for each sample, its neighbors are the other utterances with the same speaker label in the original manifest, and each neighbor file has rows of
3
+ # path\tdistance\tduration
4
+ # where distance is always 0 because we do not know the distance between the samples.
5
+
6
+ # Data filtering follows Yushen Chen's approach (see the filtering logic in process_meta_item below).
7
+ import sys, copy
8
+ import os, random, numpy as np, socket
9
+
10
+ import json
11
+ import tqdm
12
+ from multiprocessing import Pool
13
+ import glob, os
14
+ from collections import defaultdict
15
+ def write_jsonl(data, fn):
16
+ with open(fn, "w") as file:
17
+ for entry in data:
18
+ file.write(json.dumps(entry, ensure_ascii=False) + "\n")
19
+ def read_jsonl(file_path):
20
+ cur_data = []
21
+ with open(file_path, 'r', encoding='utf-8-sig') as file:
22
+ for line in file:
23
+ cur_data.append(json.loads(line.strip()))
24
+ return cur_data
25
+
26
+ def repetition_found(text, length=2, tolerance=10):
27
+ pattern_count = defaultdict(int)
28
+ for i in range(len(text) - length + 1):
29
+ pattern = text[i : i + length]
30
+ pattern_count[pattern] += 1
31
+ for pattern, count in pattern_count.items():
32
+ if count > tolerance:
33
+ return True
34
+ return False
35
+
36
+
37
+
38
+ out_en = {
39
+ "EN_B00013_S00913",
40
+ "EN_B00042_S00120",
41
+ "EN_B00055_S04111",
42
+ "EN_B00061_S00693",
43
+ "EN_B00061_S01494",
44
+ "EN_B00061_S03375",
45
+ "EN_B00059_S00092",
46
+ "EN_B00111_S04300",
47
+ "EN_B00100_S03759",
48
+ "EN_B00087_S03811",
49
+ "EN_B00059_S00950",
50
+ "EN_B00089_S00946",
51
+ "EN_B00078_S05127",
52
+ "EN_B00070_S04089",
53
+ "EN_B00074_S09659",
54
+ "EN_B00061_S06983",
55
+ "EN_B00061_S07060",
56
+ "EN_B00059_S08397",
57
+ "EN_B00082_S06192",
58
+ "EN_B00091_S01238",
59
+ "EN_B00089_S07349",
60
+ "EN_B00070_S04343",
61
+ "EN_B00061_S02400",
62
+ "EN_B00076_S01262",
63
+ "EN_B00068_S06467",
64
+ "EN_B00076_S02943",
65
+ "EN_B00064_S05954",
66
+ "EN_B00061_S05386",
67
+ "EN_B00066_S06544",
68
+ "EN_B00076_S06944",
69
+ "EN_B00072_S08620",
70
+ "EN_B00076_S07135",
71
+ "EN_B00076_S09127",
72
+ "EN_B00065_S00497",
73
+ "EN_B00059_S06227",
74
+ "EN_B00063_S02859",
75
+ "EN_B00075_S01547",
76
+ "EN_B00061_S08286",
77
+ "EN_B00079_S02901",
78
+ "EN_B00092_S03643",
79
+ "EN_B00096_S08653",
80
+ "EN_B00063_S04297",
81
+ "EN_B00063_S04614",
82
+ "EN_B00079_S04698",
83
+ "EN_B00104_S01666",
84
+ "EN_B00061_S09504",
85
+ "EN_B00061_S09694",
86
+ "EN_B00065_S05444",
87
+ "EN_B00063_S06860",
88
+ "EN_B00065_S05725",
89
+ "EN_B00069_S07628",
90
+ "EN_B00083_S03875",
91
+ "EN_B00071_S07665",
92
+ "EN_B00071_S07665",
93
+ "EN_B00062_S04187",
94
+ "EN_B00065_S09873",
95
+ "EN_B00065_S09922",
96
+ "EN_B00084_S02463",
97
+ "EN_B00067_S05066",
98
+ "EN_B00106_S08060",
99
+ "EN_B00073_S06399",
100
+ "EN_B00073_S09236",
101
+ "EN_B00087_S00432",
102
+ "EN_B00085_S05618",
103
+ "EN_B00064_S01262",
104
+ "EN_B00072_S01739",
105
+ "EN_B00059_S03913",
106
+ "EN_B00069_S04036",
107
+ "EN_B00067_S05623",
108
+ "EN_B00060_S05389",
109
+ "EN_B00060_S07290",
110
+ "EN_B00062_S08995",
111
+ }
112
+ en_filters = ["ا", "い", "て"]
113
+
114
+
115
+ from multiprocessing import Pool
116
+
117
+ def process_meta_item(item, root, sub_root, audio_folder, audio_ext, text_ext):
118
+ global filtered_duration, filtered_count, total_duration, total_count
119
+ # Data filtering following Yushen's approach
120
+ if (
121
+ item["wav"].split("/")[-1] in out_en
122
+ or any(t in item["text"] for t in en_filters)
123
+ or repetition_found(item["text"], length=4)
124
+ ):
125
+ return None, item["duration"], 1, 0, 0, (None, None) # Return filtered results
126
+
127
+ # Trim leading space from text if exists
128
+ if item["text"].startswith(" "):
129
+ item["text"] = item["text"][1:]
130
+
131
+ # write text to text file
132
+ text_fn = os.path.join(root, sub_root, audio_folder, item["wav"].replace(audio_ext, text_ext))
133
+ os.makedirs(os.path.dirname(text_fn), exist_ok=True)
134
+ with open(text_fn, "w") as f:
135
+ f.write(item["text"])
136
+
137
+ # spk2info[item["speaker"]].append(item)
138
+ return (
139
+ f"{item['wav']}\t{item['duration']}\n",
140
+ 0,
141
+ 0,
142
+ item["duration"],
143
+ 1,
144
+ (item['speaker'], item)
145
+ ) # Return processed results
146
+
147
+
148
+ def parallel_process_meta(meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext):
149
+ with Pool(num_workers) as pool:
150
+ results = pool.starmap(
151
+ process_meta_item,
152
+ [(item, root, sub_root, audio_folder, audio_ext, text_ext) for item in meta],
153
+ )
154
+
155
+ processed_items = []
156
+ spkitem = []
157
+ filtered_duration = 0
158
+ filtered_count = 0
159
+ total_duration = 0
160
+ total_count = 0
161
+
162
+ for result in results:
163
+ if result[0]: # If the item was processed
164
+ processed_items.append(result[0])
165
+ filtered_duration += result[1]
166
+ filtered_count += result[2]
167
+ total_duration += result[3]
168
+ total_count += result[4]
169
+ spkitem.append(result[5])
170
+
171
+ return processed_items, filtered_duration, filtered_count, total_duration, total_count, spkitem
172
+
173
+
174
+ def main(
175
+ root: str = "/data/scratch/pyp/datasets/emilia",
176
+ sub_root: str = "preprocessed",
177
+ audio_folder: str = "audio",
178
+ manifest_folder: str = "manifest_for_codec",
179
+ neighbors_folder: str = "neighbors",
180
+ audio_ext: str = ".mp3",
181
+ text_ext: str = ".txt",
182
+ num_workers: int = 8, # Specify the number of workers
183
+ ):
184
+ # Find the segments that are untarred
185
+ all_fns = [
186
+ item
187
+ for item in glob.glob(f"{root}/{sub_root}/{audio_folder}/*")
188
+ if os.path.basename(item).startswith("EN_") and os.path.isdir(item)
189
+ ]
190
+ print(f"found {len(all_fns)} untarred segments")
191
+ print(f"{all_fns[:3]}")
192
+
193
+ res = []
194
+ total_duration = 0
195
+ total_count = 0
196
+ filtered_duration = 0
197
+ filtered_count = 0
198
+
199
+ for fn in tqdm.tqdm(all_fns, desc="overall progress"):
200
+ spk2info = defaultdict(list)
201
+ metafn = os.path.join(root, "EN", os.path.basename(fn) + ".jsonl")
202
+ meta = read_jsonl(metafn)
203
+
204
+ # Parallel process metadata
205
+ processed_items, fd, fc, td, tc, spkitem = parallel_process_meta(
206
+ meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext
207
+ )
208
+
209
+ # Aggregate results
210
+ res.extend(processed_items)
211
+ filtered_duration += fd
212
+ filtered_count += fc
213
+ total_duration += td
214
+ total_count += tc
215
+
216
+ for spk, item in spkitem:
217
+ if spk:
218
+ spk2info[spk].append(item)
219
+
220
+ # Save neighbor files
221
+ for spk in spk2info:
222
+ for item in spk2info[spk]:
223
+ neighbor_fn = os.path.join(
224
+ root,
225
+ sub_root,
226
+ neighbors_folder,
227
+ item["wav"].replace(audio_ext, text_ext),
228
+ )
229
+ os.makedirs(os.path.dirname(neighbor_fn), exist_ok=True)
230
+ tobe_write = [f"{neighbor_item['wav'].replace(audio_ext, text_ext)}\t0\t{neighbor_item['duration']}\n" for neighbor_item in spk2info[spk] if neighbor_item["wav"] != item["wav"]]
231
+ if tobe_write:
232
+ with open(neighbor_fn, "w") as f:
233
+ f.writelines(tobe_write)
234
+
235
+ print(
236
+ f"total duration: {total_duration / 3600:.2f} hours, total count: {total_count}"
237
+ )
238
+ print(
239
+ f"filtered duration: {filtered_duration / 3600:.2f} hours, filtered count: {filtered_count}"
240
+ )
241
+ save_fn = os.path.join(root, sub_root, manifest_folder, "train.txt")
242
+ os.makedirs(os.path.dirname(save_fn), exist_ok=True)
243
+ with open(save_fn, "w") as f:
244
+ for item in res:
245
+ f.write(item)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ import fire
250
+
251
+ fire.Fire(main)
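Two quick illustrations of what this step filters and produces. repetition_found (duplicated here so the snippet is self-contained) drops utterances whose transcript repeats the same n-gram more than `tolerance` times; the neighbors files are plain TSV with rows of path\tdistance\tduration where distance is always 0.

from collections import defaultdict

def repetition_found(text, length=2, tolerance=10):   # same logic as above
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern_count[text[i:i + length]] += 1
    return any(count > tolerance for count in pattern_count.values())

print(repetition_found("ha" * 40, length=4, tolerance=10))   # True: "haha" repeats far too often
print(repetition_found("a normal sentence", length=4))       # False

# Reading a neighbors file back (the path below is hypothetical):
# with open(".../preprocessed/neighbors/EN_B00100/xxx.txt") as f:
#     neighbors = [line.strip().split("\t") for line in f]   # [path, distance, duration]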
data/emilia_preprocessing/step5_phonemize.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, copy
2
+ import os, random, numpy as np, socket
3
+
4
+ import json
5
+ import tqdm
6
+ from multiprocessing import Pool
7
+ import glob, os, fire
8
+ from collections import defaultdict
9
+ sys.path.insert(0, "../../")
10
+ from data.tokenizer import TextTokenizer, tokenize_text
11
+
12
+ def write_jsonl(data, fn):
13
+ with open(fn, "w") as file:
14
+ for entry in data:
15
+ file.write(json.dumps(entry, ensure_ascii=False) + "\n")
16
+
17
+
18
+ def read_jsonl(file_path):
19
+ cur_data = []
20
+ with open(file_path, 'r', encoding='utf-8-sig') as file:
21
+ for line in file:
22
+ cur_data.append(json.loads(line.strip()))
23
+ return cur_data
24
+
25
+
26
+ def phonemize_and_save(text, fn, text_tokenizer):
27
+ """Phonemizes the text and saves the result to a file."""
28
+ phn = tokenize_text(text_tokenizer, text)
29
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
30
+ with open(fn, "w") as f:
31
+ f.write(" ".join(phn))
32
+ return set(phn)
33
+
34
+
35
+ def process_item(item, root, sub_root, audio_folder, phn_folder, audio_ext, text_ext, phn_ext, text_tokenizer):
36
+ """Worker function to process a single item."""
37
+ text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext))
38
+ if not os.path.exists(text_path):
39
+ return {"missing_text": text_path, "success": False, "cur_phn_set": set()}
40
+
41
+ with open(text_path, "r") as f:
42
+ text = [line.strip() for line in f.readlines()]
43
+ text = " ".join(text)
44
+
45
+ phn_path = os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext))
46
+ cur_phn_set = phonemize_and_save(text, phn_path, text_tokenizer)
47
+ return {"missing_text": None, "success": True, "cur_phn_set": cur_phn_set}
48
+
49
+
50
+ def process_item_star(args):
51
+ """Unpacks arguments for `process_item` to work with `imap`."""
52
+ return process_item(*args)
53
+
54
+ def main(
55
+ root="/data/scratch/pyp/datasets/emilia",
56
+ sub_root="preprocessed",
57
+ manifest_folder="manifest_for_codec",
58
+ audio_folder="audio",
59
+ phn_folder="phoneme",
60
+ audio_ext=".mp3",
61
+ text_ext=".txt",
62
+ phn_ext=".txt",
63
+ num_workers=8,
64
+ ):
65
+ """Main function to process phoneme generation in parallel."""
66
+ # # Initialize the tokenizer
67
+ text_tokenizer = TextTokenizer()
68
+ all_fns = glob.glob(f"{root}/{sub_root}/{manifest_folder}/*.txt")
69
+ print(f"found {len(all_fns)} manifest files")
70
+ print(f"{all_fns[:3]=}")
71
+
72
+ data = []
73
+ for fn in all_fns:
74
+ with open(fn, "r") as f:
75
+ data += [line.strip().split("\t") for line in f]
76
+
77
+ vocab = set()
78
+
79
+ ################## parallel processing ##################
80
+ ################## parallel processing ##################
81
+ ################## parallel processing ##################
82
+ # Prepare arguments for the worker function
83
+ # tasks = [
84
+ # (
85
+ # item,
86
+ # root,
87
+ # sub_root,
88
+ # audio_folder,
89
+ # phn_folder,
90
+ # audio_ext,
91
+ # text_ext,
92
+ # phn_ext,
93
+ # text_tokenizer,
94
+ # )
95
+ # for item in data
96
+ # ]
97
+
98
+ # # Parallel processing with progress monitoring
99
+ # results = []
100
+ # with Pool(num_workers) as pool:
101
+ # for result in tqdm.tqdm(
102
+ # pool.imap_unordered(process_item_star, tasks),
103
+ # total=len(tasks),
104
+ # desc="Processing items",
105
+ # ):
106
+ # results.append(result)
107
+ # # read all manifest endswith .txt
108
+ # missing_text = [result["missing_text"] for result in results if not result["success"]]
109
+ # for result in results:
110
+ # if result['success']:
111
+ # vocab.update(result['cur_phn_set'])
112
+ ################## parallel processing ##################
113
+ ################## parallel processing ##################
114
+ ################## parallel processing ##################
115
+
116
+ ################## sequential processing ##################
117
+ ################## sequential processing ##################
118
+ ################## sequential processing ##################
119
+ missing_text = []
120
+ for item in tqdm.tqdm(data):
121
+ text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext))
122
+ if not os.path.exists(text_path):
123
+ missing_text.append(text_path)
124
+ continue
125
+ try:
126
+ with open(text_path, "r") as f:
127
+ text = [line.strip() for line in f.readlines()]
128
+ text = " ".join(text)
129
+ except Exception as e:
130
+ print(f"Error reading {text_path}: {e}")
131
+ continue
132
+ cur_phn_set = phonemize_and_save(text, os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext)), text_tokenizer)
133
+ vocab.update(cur_phn_set)
134
+ ################## sequential processing ##################
135
+ ################## sequential processing ##################
136
+ ################## sequential processing ##################
137
+
138
+ # save the vocab
139
+ vocab = list(vocab)
140
+ # sort the vocab
141
+ vocab.sort()
142
+ with open(os.path.join(root, sub_root, "vocab.txt"), "w") as f:
143
+ f.write("\n".join(vocab))
144
+
145
+ # Collect missing text paths
146
+ print(f"Missing text files: {len(missing_text)}")
147
+ if missing_text:
148
+ print("Some missing files:", missing_text[:10]) # Print the first 10 missing files as an example
149
+
150
+
151
+ if __name__ == "__main__":
152
+ fire.Fire(main)
153
+
154
+
155
+
156
+
157
+
158
+
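A minimal sketch of what phonemize_and_save does for a single sentence, using the same TextTokenizer/tokenize_text imports from data/tokenizer.py that the script uses; the exact phoneme output depends on the installed espeak backend, so the printed string is only illustrative.

import sys
sys.path.insert(0, "../../")                       # run from data/emilia_preprocessing, as above
from data.tokenizer import TextTokenizer, tokenize_text

text_tokenizer = TextTokenizer()
phn = tokenize_text(text_tokenizer, "Hello world, this is a test.")
print(" ".join(phn))                               # space-joined phonemes, as written to the phoneme .txt files
print(sorted(set(phn)))                            # the unique phonemes that feed vocab.txt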
data/emilia_preprocessing/step6_encodec_encode.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ def parse_args():
4
+ parser = argparse.ArgumentParser(description="encode the dataset using codec model")
5
+ parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory")
6
+ parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
7
+ parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
8
+ parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
9
+ parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding; decrease it if OOM. This is the total batch size summed over all GPUs, so increase it if you are using more GPUs")
10
+ parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
11
+ parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
12
+ parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
13
+ parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
14
+ parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number')
15
+ parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number')
16
+ parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
17
+ parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
18
+ return parser.parse_args()
19
+
20
+ if __name__ == "__main__":
21
+ import logging
22
+ formatter = (
23
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
24
+ )
25
+ logging.basicConfig(format=formatter, level=logging.INFO)
26
+
27
+ import os, sys
28
+ import numpy as np
29
+ import torch
30
+ import torchaudio
31
+ import tqdm
32
+ import time
33
+
34
+ args = parse_args()
35
+
36
+ def sort_by_audio_len(lens):
37
+ inds = np.argsort(lens).tolist()
38
+
39
+ logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
40
+ logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
41
+ logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
42
+ logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
43
+ return inds[::-1]
44
+
45
+ def write_array_to_txt_file(array, filename):
46
+ with open(filename, 'w') as f:
47
+ for a in array[:-1]:
48
+ f.write(' '.join(map(str, a))+'\n')
49
+ f.write(' '.join(map(str, array[-1])))
50
+
51
+ class mydataset(torch.utils.data.Dataset):
52
+ def __init__(self, split):
53
+ super().__init__()
54
+ self.split = split
55
+ self.audio_dir = audio_dir
56
+ manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
57
+ cur_sp = int(args.partition.split("/")[0])-1
58
+ total_sp = int(args.partition.split("/")[1])
59
+ with open(manifest_fn, "r") as rf:
60
+ self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]
61
+ def __len__(self):
62
+ return len(self.data)
63
+ def __getitem__(self, ind):
64
+ try:
65
+ afn = self.data[ind][0]
66
+ fn = os.path.join(self.audio_dir, afn)
67
+ audio, sr = torchaudio.load(fn)
68
+ if sr != args.model_sr:
69
+ audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
70
+ sr = args.model_sr
71
+ assert sr == args.model_sr, sr
72
+ except Exception as e:
73
+ # logging.info(f"{e}")
74
+ return None, None, None
75
+ assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
76
+ return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]
77
+ def collate(self, batch):
78
+ lens, audios, segment_ids = [], [], []
79
+ for item in batch:
80
+ if item[0] is not None:
81
+ audios.append(item[0])
82
+ lens.append(item[1])
83
+ segment_ids.append(item[2])
84
+ return audios, lens, segment_ids
85
+
86
+ # roots
87
+ sub_root = args.sub_root
88
+ encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
89
+ audio_dir = os.path.join(args.root, sub_root, "audio")
90
+ save_manifest_dir = os.path.join(args.root, sub_root,"manifest_final_encodec")
91
+ if args.encodec_name == "encodec_6f79c6a8.th":
92
+ save_codes_dir = os.path.join(args.root, sub_root,"encodec_4cb")
93
+ elif args.encodec_name == "encodec_8cb1024_giga.th":
94
+ save_codes_dir = os.path.join(args.root, sub_root,"encodec_8cb")
95
+
96
+ os.makedirs(save_manifest_dir, exist_ok=True)
97
+ os.makedirs(save_codes_dir, exist_ok=True)
98
+
99
+ def import_encodec():
100
+ from encodec import get_compression_model
101
+ userdir = os.path.expanduser("~")
102
+ model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
103
+ model = torch.nn.DataParallel(model)
104
+ return model
105
+ model = import_encodec()
106
+
107
+ # setup dataloader
108
+ mega_batch_size = 2048
109
+ batch_size = args.batch_size
110
+
111
+ dataset = mydataset(args.split)
112
+ if len(dataset) == 0:
113
+ logging.info(f"no data found for split {args.split} partition {args.partition}")
114
+ sys.exit(0)
115
+ loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
116
+ split = args.split
117
+
118
+ skip = 0
119
+ logging.info(f"now processing split {split} partition {args.partition}...")
120
+ mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
121
+ # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
122
+ logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
123
+ mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
124
+ logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
125
+ with open(mani_fn, "w") as mani_wf:
126
+ # with open(mani_fn, "a") as mani_wf: # resume from where we failed
127
+ for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):
128
+
129
+ logging.info(f"====================================")
130
+ logging.info(f"====================================")
131
+ logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
132
+
133
+ try:
134
+ lengths = np.array(mega_batch[1])
135
+ sorted_inds = sort_by_audio_len(lengths)
136
+ for j in range(len(sorted_inds))[::-1]:
137
+ if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples shorter than min_len seconds or longer than len_cap seconds
138
+ skip += 1
139
+ del sorted_inds[j]
140
+
141
+ n_steps = int(np.ceil(len(sorted_inds) / batch_size))
142
+ for n in tqdm.tqdm(range(n_steps), disable=True):
143
+ inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
144
+ wav_batch = [mega_batch[0][id] for id in inds_used]
145
+ all_lens = [mega_batch[1][id] for id in inds_used]
146
+ segment_id_batch = [mega_batch[2][id] for id in inds_used]
147
+ padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
148
+ # Extract discrete codes from EnCodec
149
+ with torch.no_grad():
150
+ if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time
151
+ codes = []
152
+ inwav = padded_wav.cuda()
153
+ codes.append(model(inwav[:len(inwav)//2])[0].cpu())
154
+ codes.append(model(inwav[len(inwav)//2:])[0].cpu())
155
+ codes = torch.cat(codes, dim=0)
156
+ else:
157
+ encoded_frames = model(padded_wav.cuda())
158
+ codes = encoded_frames[0].cpu() # [B, n_codebook, T]
159
+
160
+ for i, length in enumerate(all_lens):
161
+ save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
162
+ actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
163
+ cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
164
+ os.makedirs(os.path.dirname(save_fn), exist_ok=True)
165
+ write_array_to_txt_file(cur_code, save_fn)
166
+
167
+ mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
168
+ # if i == 10:
169
+ # raise
170
+ except Exception as e:
171
+ print(f'exception!! at {m+1}')
172
+ print(e)
173
+ continue
174
+
175
+ # break
176
+ logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
177
+ # break
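For reference on the on-disk format produced above: each utterance's EnCodec codes (shape [n_codebook, T]) are written by `write_array_to_txt_file` as n_codebook lines of space-separated integers under `encodec_4cb/` or `encodec_8cb/`, and the per-partition manifest records `segment_id<TAB>num_frames`. A minimal sketch for reading one of these code files back into a tensor (the function name is an assumption for illustration):

    import torch

    def read_codes_txt(path):
        # one line per codebook stream, each a run of space-separated integer codes
        with open(path, "r") as f:
            rows = [[int(tok) for tok in line.split()] for line in f if line.strip()]
        return torch.tensor(rows, dtype=torch.long)  # shape [n_codebook, T]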
data/emilia_preprocessing/step6_encodec_encode_script.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+ source ~/miniconda3/etc/profile.d/conda.sh
3
+ conda activate voicestar
4
+
5
+ dir=${processed_dir:-/data/scratch/pyp/datasets/emilia}
6
+ sub_root=${sub_root:-preprocessed}
7
+ encodec_name=${encodec_name:-"encodec_6f79c6a8.th"}
8
+ n_workers=${n_workers:-64}
9
+ batch_size=${batch_size:-512}
10
+ audio_sr=16000
11
+ model_sr=16000
12
+ downsample_rate=320
13
+ model_code_sr=50
14
+ len_cap=1000
15
+ min_len=0.5
16
+ partition=${partition:-"1/1"}
17
+ split=${split:-"train"}
18
+
19
+ python step6_encodec_encode.py --root $dir --sub_root ${sub_root} --encodec_name ${encodec_name} --n_workers $n_workers --batch_size $batch_size --audio_sr $audio_sr --model_sr $model_sr --downsample_rate $downsample_rate --model_code_sr $model_code_sr --len_cap $len_cap --min_len $min_len --partition $partition --split $split
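The script above is meant to be sharded by setting the `partition` environment variable (e.g. `partition=3/8 split=train bash step6_encodec_encode_script.sh`), and each shard writes its own manifest named `{split}_{partition with '/' replaced by '='}.txt` into `manifest_final_encodec`. A merge step is not shown in this diff, so the following is only a hedged sketch of how the per-partition manifests could be concatenated afterwards (the helper name is hypothetical):

    import glob, os

    def merge_partition_manifests(save_manifest_dir, split="train"):
        # e.g. train_1=8.txt ... train_8=8.txt -> train.txt
        parts = sorted(glob.glob(os.path.join(save_manifest_dir, f"{split}_*.txt")))
        with open(os.path.join(save_manifest_dir, f"{split}.txt"), "w") as out:
            for p in parts:
                with open(p, "r") as f:
                    out.write(f.read())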
data/encodec.py ADDED
@@ -0,0 +1,1554 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from dataclasses import dataclass, field
12
+ import logging
13
+ import math
14
+ from pathlib import Path
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ from torch import einsum
21
+ import torch.nn.functional as F
22
+ from torch.nn.utils import spectral_norm, weight_norm
23
+
24
+ import logging
25
+ import warnings
26
+ from einops import rearrange, repeat
27
+ import omegaconf
28
+ # import flashy
29
+
30
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
31
+ 'time_group_norm'])
32
+
33
+ def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
34
+ """Convenience function to map an omegaconf configuration to a dictionary.
35
+
36
+ Args:
37
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
38
+ Returns:
39
+ dict: Config as dictionary object.
40
+ """
41
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
42
+ assert isinstance(dct, dict)
43
+ return dct
44
+
45
+ @dataclass
46
+ class QuantizedResult:
47
+ x: torch.Tensor
48
+ codes: torch.Tensor
49
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
50
+ penalty: tp.Optional[torch.Tensor] = None
51
+ metrics: dict = field(default_factory=dict)
52
+
53
+ class BaseQuantizer(nn.Module):
54
+ """Base class for quantizers.
55
+ """
56
+
57
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
58
+ """
59
+ Given input tensor x, returns first the quantized (or approximately quantized)
60
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
61
+ Finally, this returns a dict of metrics to update logging etc.
62
+ Frame rate must be passed so that the bandwidth is properly computed.
63
+ """
64
+ raise NotImplementedError()
65
+
66
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
67
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
68
+ raise NotImplementedError()
69
+
70
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
71
+ """Decode the given codes to the quantized representation."""
72
+ raise NotImplementedError()
73
+
74
+ @property
75
+ def total_codebooks(self):
76
+ """Total number of codebooks."""
77
+ raise NotImplementedError()
78
+
79
+ @property
80
+ def num_codebooks(self):
81
+ """Number of active codebooks."""
82
+ raise NotImplementedError()
83
+
84
+ def set_num_codebooks(self, n: int):
85
+ """Set the number of active codebooks."""
86
+ raise NotImplementedError()
87
+
88
+ class CompressionModel(ABC, nn.Module):
89
+ """Base API for all compression model that aim at being used as audio tokenizers
90
+ with a language model.
91
+ """
92
+
93
+ @abstractmethod
94
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
95
+ ...
96
+
97
+ @abstractmethod
98
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
99
+ """See `EncodecModel.encode`."""
100
+ ...
101
+
102
+ @abstractmethod
103
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
104
+ """See `EncodecModel.decode`."""
105
+ ...
106
+
107
+ @abstractmethod
108
+ def decode_latent(self, codes: torch.Tensor):
109
+ """Decode from the discrete codes to continuous latent space."""
110
+ ...
111
+
112
+ @property
113
+ @abstractmethod
114
+ def channels(self) -> int:
115
+ ...
116
+
117
+ @property
118
+ @abstractmethod
119
+ def frame_rate(self) -> float:
120
+ ...
121
+
122
+ @property
123
+ @abstractmethod
124
+ def sample_rate(self) -> int:
125
+ ...
126
+
127
+ @property
128
+ @abstractmethod
129
+ def cardinality(self) -> int:
130
+ ...
131
+
132
+ @property
133
+ @abstractmethod
134
+ def num_codebooks(self) -> int:
135
+ ...
136
+
137
+ @property
138
+ @abstractmethod
139
+ def total_codebooks(self) -> int:
140
+ ...
141
+
142
+ @abstractmethod
143
+ def set_num_codebooks(self, n: int):
144
+ """Set the active number of codebooks used by the quantizer."""
145
+ ...
146
+
147
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
148
+ assert norm in CONV_NORMALIZATIONS
149
+ if norm == 'weight_norm':
150
+ return weight_norm(module)
151
+ elif norm == 'spectral_norm':
152
+ return spectral_norm(module)
153
+ else:
154
+ # We already check was in CONV_NORMALIZATION, so any other choice
155
+ # doesn't need reparametrization.
156
+ return module
157
+
158
+
159
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
160
+ """Return the proper normalization module. If causal is True, this will ensure the returned
161
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
162
+ """
163
+ assert norm in CONV_NORMALIZATIONS
164
+ if norm == 'time_group_norm':
165
+ if causal:
166
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
167
+ assert isinstance(module, nn.modules.conv._ConvNd)
168
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
169
+ else:
170
+ return nn.Identity()
171
+
172
+
173
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
174
+ padding_total: int = 0) -> int:
175
+ """See `pad_for_conv1d`."""
176
+ length = x.shape[-1]
177
+ n_frames = (length - kernel_size + padding_total) / stride + 1
178
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
179
+ return ideal_length - length
180
+
181
+
182
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
183
+ """Pad for a convolution to make sure that the last window is full.
184
+ Extra padding is added at the end. This is required to ensure that we can rebuild
185
+ an output of the same length, as otherwise, even with padding, some time steps
186
+ might get removed.
187
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
188
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
189
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
190
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
191
+ 1 2 3 4 # once you removed padding, we are missing one time step !
192
+ """
193
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
194
+ return F.pad(x, (0, extra_padding))
195
+
196
+
197
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
198
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
199
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
200
+ """
201
+ length = x.shape[-1]
202
+ padding_left, padding_right = paddings
203
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
204
+ if mode == 'reflect':
205
+ max_pad = max(padding_left, padding_right)
206
+ extra_pad = 0
207
+ if length <= max_pad:
208
+ extra_pad = max_pad - length + 1
209
+ x = F.pad(x, (0, extra_pad))
210
+ padded = F.pad(x, paddings, mode, value)
211
+ end = padded.shape[-1] - extra_pad
212
+ return padded[..., :end]
213
+ else:
214
+ return F.pad(x, paddings, mode, value)
215
+
216
+
217
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
218
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
219
+ padding_left, padding_right = paddings
220
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
221
+ assert (padding_left + padding_right) <= x.shape[-1]
222
+ end = x.shape[-1] - padding_right
223
+ return x[..., padding_left: end]
224
+
225
+
226
+ class NormConv1d(nn.Module):
227
+ """Wrapper around Conv1d and normalization applied to this conv
228
+ to provide a uniform interface across normalization approaches.
229
+ """
230
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
231
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
232
+ super().__init__()
233
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
234
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
235
+ self.norm_type = norm
236
+
237
+ def forward(self, x):
238
+ x = self.conv(x)
239
+ x = self.norm(x)
240
+ return x
241
+
242
+
243
+ class NormConv2d(nn.Module):
244
+ """Wrapper around Conv2d and normalization applied to this conv
245
+ to provide a uniform interface across normalization approaches.
246
+ """
247
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
248
+ super().__init__()
249
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
250
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
251
+ self.norm_type = norm
252
+
253
+ def forward(self, x):
254
+ x = self.conv(x)
255
+ x = self.norm(x)
256
+ return x
257
+
258
+
259
+ class NormConvTranspose1d(nn.Module):
260
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
261
+ to provide a uniform interface across normalization approaches.
262
+ """
263
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
264
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
265
+ super().__init__()
266
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
267
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
268
+ self.norm_type = norm
269
+
270
+ def forward(self, x):
271
+ x = self.convtr(x)
272
+ x = self.norm(x)
273
+ return x
274
+
275
+
276
+ class NormConvTranspose2d(nn.Module):
277
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
278
+ to provide a uniform interface across normalization approaches.
279
+ """
280
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
281
+ super().__init__()
282
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
283
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
284
+
285
+ def forward(self, x):
286
+ x = self.convtr(x)
287
+ x = self.norm(x)
288
+ return x
289
+
290
+
291
+ class StreamableConv1d(nn.Module):
292
+ """Conv1d with some builtin handling of asymmetric or causal padding
293
+ and normalization.
294
+ """
295
+ def __init__(self, in_channels: int, out_channels: int,
296
+ kernel_size: int, stride: int = 1, dilation: int = 1,
297
+ groups: int = 1, bias: bool = True, causal: bool = False,
298
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
299
+ pad_mode: str = 'reflect'):
300
+ super().__init__()
301
+ # warn user on unusual setup between dilation and stride
302
+ if stride > 1 and dilation > 1:
303
+ warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
304
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
305
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
306
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
307
+ norm=norm, norm_kwargs=norm_kwargs)
308
+ self.causal = causal
309
+ self.pad_mode = pad_mode
310
+
311
+ def forward(self, x):
312
+ B, C, T = x.shape
313
+ kernel_size = self.conv.conv.kernel_size[0]
314
+ stride = self.conv.conv.stride[0]
315
+ dilation = self.conv.conv.dilation[0]
316
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
317
+ padding_total = kernel_size - stride
318
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
319
+ if self.causal:
320
+ # Left padding for causal
321
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
322
+ else:
323
+ # Asymmetric padding required for odd strides
324
+ padding_right = padding_total // 2
325
+ padding_left = padding_total - padding_right
326
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
327
+ return self.conv(x)
328
+
329
+
330
+ class StreamableConvTranspose1d(nn.Module):
331
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
332
+ and normalization.
333
+ """
334
+ def __init__(self, in_channels: int, out_channels: int,
335
+ kernel_size: int, stride: int = 1, causal: bool = False,
336
+ norm: str = 'none', trim_right_ratio: float = 1.,
337
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
338
+ super().__init__()
339
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
340
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
341
+ self.causal = causal
342
+ self.trim_right_ratio = trim_right_ratio
343
+ assert self.causal or self.trim_right_ratio == 1., \
344
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
345
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
346
+
347
+ def forward(self, x):
348
+ kernel_size = self.convtr.convtr.kernel_size[0]
349
+ stride = self.convtr.convtr.stride[0]
350
+ padding_total = kernel_size - stride
351
+
352
+ y = self.convtr(x)
353
+
354
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
355
+ # removed at the very end, when keeping only the right length for the output,
356
+ # as removing it here would require also passing the length at the matching layer
357
+ # in the encoder.
358
+ if self.causal:
359
+ # Trim the padding on the right according to the specified ratio
360
+ # if trim_right_ratio = 1.0, trim everything from right
361
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
362
+ padding_left = padding_total - padding_right
363
+ y = unpad1d(y, (padding_left, padding_right))
364
+ else:
365
+ # Asymmetric padding required for odd strides
366
+ padding_right = padding_total // 2
367
+ padding_left = padding_total - padding_right
368
+ y = unpad1d(y, (padding_left, padding_right))
369
+ return y
370
+
371
+
372
+ class StreamableLSTM(nn.Module):
373
+ """LSTM without worrying about the hidden state, nor the layout of the data.
374
+ Expects input as convolutional layout.
375
+ """
376
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
377
+ super().__init__()
378
+ self.skip = skip
379
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
380
+
381
+ def forward(self, x):
382
+ x = x.permute(2, 0, 1)
383
+ y, _ = self.lstm(x)
384
+ if self.skip:
385
+ y = y + x
386
+ y = y.permute(1, 2, 0)
387
+ return y
388
+
389
+
390
+ class SEANetResnetBlock(nn.Module):
391
+ """Residual block from SEANet model.
392
+
393
+ Args:
394
+ dim (int): Dimension of the input/output.
395
+ kernel_sizes (list): List of kernel sizes for the convolutions.
396
+ dilations (list): List of dilations for the convolutions.
397
+ activation (str): Activation function.
398
+ activation_params (dict): Parameters to provide to the activation function.
399
+ norm (str): Normalization method.
400
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
401
+ causal (bool): Whether to use fully causal convolution.
402
+ pad_mode (str): Padding mode for the convolutions.
403
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
404
+ true_skip (bool): Whether to use true skip connection or a simple
405
+ (streamable) convolution as the skip connection.
406
+ """
407
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
408
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
409
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
410
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
411
+ super().__init__()
412
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
413
+ act = getattr(nn, activation)
414
+ hidden = dim // compress
415
+ block = []
416
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
417
+ in_chs = dim if i == 0 else hidden
418
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
419
+ block += [
420
+ act(**activation_params),
421
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
422
+ norm=norm, norm_kwargs=norm_params,
423
+ causal=causal, pad_mode=pad_mode),
424
+ ]
425
+ self.block = nn.Sequential(*block)
426
+ self.shortcut: nn.Module
427
+ if true_skip:
428
+ self.shortcut = nn.Identity()
429
+ else:
430
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
431
+ causal=causal, pad_mode=pad_mode)
432
+
433
+ def forward(self, x):
434
+ return self.shortcut(x) + self.block(x)
435
+
436
+
437
+ class SEANetEncoder(nn.Module):
438
+ """SEANet encoder.
439
+
440
+ Args:
441
+ channels (int): Audio channels.
442
+ dimension (int): Intermediate representation dimension.
443
+ n_filters (int): Base width for the model.
444
+ n_residual_layers (int): nb of residual layers.
445
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
446
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
447
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
448
+ activation (str): Activation function.
449
+ activation_params (dict): Parameters to provide to the activation function.
450
+ norm (str): Normalization method.
451
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
452
+ kernel_size (int): Kernel size for the initial convolution.
453
+ last_kernel_size (int): Kernel size for the initial convolution.
454
+ residual_kernel_size (int): Kernel size for the residual layers.
455
+ dilation_base (int): How much to increase the dilation with each layer.
456
+ causal (bool): Whether to use fully causal convolution.
457
+ pad_mode (str): Padding mode for the convolutions.
458
+ true_skip (bool): Whether to use true skip connection or a simple
459
+ (streamable) convolution as the skip connection in the residual network blocks.
460
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
461
+ lstm (int): Number of LSTM layers at the end of the encoder.
462
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
463
+ For the encoder, it corresponds to the N first blocks.
464
+ """
465
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
466
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
467
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
468
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
469
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
470
+ disable_norm_outer_blocks: int = 0):
471
+ super().__init__()
472
+ self.channels = channels
473
+ self.dimension = dimension
474
+ self.n_filters = n_filters
475
+ self.ratios = list(reversed(ratios))
476
+ del ratios
477
+ self.n_residual_layers = n_residual_layers
478
+ self.hop_length = np.prod(self.ratios)
479
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
480
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
481
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
482
+ "Number of blocks for which to disable norm is invalid." \
483
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
484
+
485
+ act = getattr(nn, activation)
486
+ mult = 1
487
+ model: tp.List[nn.Module] = [
488
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
489
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
490
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
491
+ ]
492
+ # Downsample to raw audio scale
493
+ for i, ratio in enumerate(self.ratios):
494
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
495
+ # Add residual layers
496
+ for j in range(n_residual_layers):
497
+ model += [
498
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
499
+ dilations=[dilation_base ** j, 1],
500
+ norm=block_norm, norm_params=norm_params,
501
+ activation=activation, activation_params=activation_params,
502
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
503
+
504
+ # Add downsampling layers
505
+ model += [
506
+ act(**activation_params),
507
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
508
+ kernel_size=ratio * 2, stride=ratio,
509
+ norm=block_norm, norm_kwargs=norm_params,
510
+ causal=causal, pad_mode=pad_mode),
511
+ ]
512
+ mult *= 2
513
+
514
+ if lstm:
515
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
516
+
517
+ model += [
518
+ act(**activation_params),
519
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
520
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
521
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
522
+ ]
523
+
524
+ self.model = nn.Sequential(*model)
525
+
526
+ def forward(self, x):
527
+ return self.model(x)
528
+
529
+
530
+ class SEANetDecoder(nn.Module):
531
+ """SEANet decoder.
532
+
533
+ Args:
534
+ channels (int): Audio channels.
535
+ dimension (int): Intermediate representation dimension.
536
+ n_filters (int): Base width for the model.
537
+ n_residual_layers (int): nb of residual layers.
538
+ ratios (Sequence[int]): kernel size and stride ratios.
539
+ activation (str): Activation function.
540
+ activation_params (dict): Parameters to provide to the activation function.
541
+ final_activation (str): Final activation function after all convolutions.
542
+ final_activation_params (dict): Parameters to provide to the activation function.
543
+ norm (str): Normalization method.
544
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
545
+ kernel_size (int): Kernel size for the initial convolution.
546
+ last_kernel_size (int): Kernel size for the initial convolution.
547
+ residual_kernel_size (int): Kernel size for the residual layers.
548
+ dilation_base (int): How much to increase the dilation with each layer.
549
+ causal (bool): Whether to use fully causal convolution.
550
+ pad_mode (str): Padding mode for the convolutions.
551
+ true_skip (bool): Whether to use true skip connection or a simple.
552
+ (streamable) convolution as the skip connection in the residual network blocks.
553
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
554
+ lstm (int): Number of LSTM layers at the end of the encoder.
555
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
556
+ For the decoder, it corresponds to the N last blocks.
557
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
558
+ If equal to 1.0, it means that all the trimming is done at the right.
559
+ """
560
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
561
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
562
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
563
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
564
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
565
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
566
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
567
+ super().__init__()
568
+ self.dimension = dimension
569
+ self.channels = channels
570
+ self.n_filters = n_filters
571
+ self.ratios = ratios
572
+ del ratios
573
+ self.n_residual_layers = n_residual_layers
574
+ self.hop_length = np.prod(self.ratios)
575
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
576
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
577
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
578
+ "Number of blocks for which to disable norm is invalid." \
579
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
580
+
581
+ act = getattr(nn, activation)
582
+ mult = int(2 ** len(self.ratios))
583
+ model: tp.List[nn.Module] = [
584
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
585
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
586
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
587
+ ]
588
+
589
+ if lstm:
590
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
591
+
592
+ # Upsample to raw audio scale
593
+ for i, ratio in enumerate(self.ratios):
594
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
595
+ # Add upsampling layers
596
+ model += [
597
+ act(**activation_params),
598
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
599
+ kernel_size=ratio * 2, stride=ratio,
600
+ norm=block_norm, norm_kwargs=norm_params,
601
+ causal=causal, trim_right_ratio=trim_right_ratio),
602
+ ]
603
+ # Add residual layers
604
+ for j in range(n_residual_layers):
605
+ model += [
606
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
607
+ dilations=[dilation_base ** j, 1],
608
+ activation=activation, activation_params=activation_params,
609
+ norm=block_norm, norm_params=norm_params, causal=causal,
610
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
611
+
612
+ mult //= 2
613
+
614
+ # Add final layers
615
+ model += [
616
+ act(**activation_params),
617
+ StreamableConv1d(n_filters, channels, last_kernel_size,
618
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
619
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
620
+ ]
621
+ # Add optional final activation to decoder (eg. tanh)
622
+ if final_activation is not None:
623
+ final_act = getattr(nn, final_activation)
624
+ final_activation_params = final_activation_params or {}
625
+ model += [
626
+ final_act(**final_activation_params)
627
+ ]
628
+ self.model = nn.Sequential(*model)
629
+
630
+ def forward(self, z):
631
+ y = self.model(z)
632
+ return y
633
+
634
+
635
+ def exists(val: tp.Optional[tp.Any]) -> bool:
636
+ return val is not None
637
+
638
+
639
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
640
+ return val if exists(val) else d
641
+
642
+
643
+ def l2norm(t):
644
+ return F.normalize(t, p=2, dim=-1)
645
+
646
+
647
+ def ema_inplace(moving_avg, new, decay: float):
648
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
649
+
650
+
651
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
652
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
653
+
654
+
655
+ def uniform_init(*shape: int):
656
+ t = torch.empty(shape)
657
+ nn.init.kaiming_uniform_(t)
658
+ return t
659
+
660
+
661
+ def sample_vectors(samples, num: int):
662
+ num_samples, device = samples.shape[0], samples.device
663
+
664
+ if num_samples >= num:
665
+ indices = torch.randperm(num_samples, device=device)[:num]
666
+ else:
667
+ indices = torch.randint(0, num_samples, (num,), device=device)
668
+
669
+ return samples[indices]
670
+
671
+
672
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
673
+ dim, dtype = samples.shape[-1], samples.dtype
674
+
675
+ means = sample_vectors(samples, num_clusters)
676
+
677
+ for _ in range(num_iters):
678
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
679
+ means, "c d -> () c d"
680
+ )
681
+ dists = -(diffs ** 2).sum(dim=-1)
682
+
683
+ buckets = dists.max(dim=-1).indices
684
+ bins = torch.bincount(buckets, minlength=num_clusters)
685
+ zero_mask = bins == 0
686
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
687
+
688
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
689
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
690
+ new_means = new_means / bins_min_clamped[..., None]
691
+
692
+ means = torch.where(zero_mask[..., None], means, new_means)
693
+
694
+ return means, bins
695
+
696
+
697
+ def orthogonal_loss_fn(t):
698
+ # eq (2) from https://arxiv.org/abs/2112.00384
699
+ n = t.shape[0]
700
+ normed_codes = l2norm(t)
701
+ identity = torch.eye(n, device=t.device)
702
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
703
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
704
+
705
+
706
+ class EuclideanCodebook(nn.Module):
707
+ """Codebook with Euclidean distance.
708
+
709
+ Args:
710
+ dim (int): Dimension.
711
+ codebook_size (int): Codebook size.
712
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
713
+ If set to true, run the k-means algorithm on the first training batch and use
714
+ the learned centroids as initialization.
715
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
716
+ decay (float): Decay for exponential moving average over the codebooks.
717
+ epsilon (float): Epsilon value for numerical stability.
718
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
719
+ that have an exponential moving average cluster size less than the specified threshold with
720
+ randomly selected vector from the current batch.
721
+ """
722
+ def __init__(
723
+ self,
724
+ dim: int,
725
+ codebook_size: int,
726
+ kmeans_init: int = False,
727
+ kmeans_iters: int = 10,
728
+ decay: float = 0.8,
729
+ epsilon: float = 1e-5,
730
+ threshold_ema_dead_code: int = 2,
731
+ ):
732
+ super().__init__()
733
+ self.decay = decay
734
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
735
+ embed = init_fn(codebook_size, dim)
736
+
737
+ self.codebook_size = codebook_size
738
+
739
+ self.kmeans_iters = kmeans_iters
740
+ self.epsilon = epsilon
741
+ self.threshold_ema_dead_code = threshold_ema_dead_code
742
+
743
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
744
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
745
+ self.register_buffer("embed", embed)
746
+ self.register_buffer("embed_avg", embed.clone())
747
+
748
+ @torch.jit.ignore
749
+ def init_embed_(self, data):
750
+ if self.inited:
751
+ return
752
+
753
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
754
+ self.embed.data.copy_(embed)
755
+ self.embed_avg.data.copy_(embed.clone())
756
+ self.cluster_size.data.copy_(cluster_size)
757
+ self.inited.data.copy_(torch.Tensor([True]))
758
+ # Make sure all buffers across workers are in sync after initialization
759
+ flashy.distrib.broadcast_tensors(self.buffers())
760
+
761
+ def replace_(self, samples, mask):
762
+ modified_codebook = torch.where(
763
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
764
+ )
765
+ self.embed.data.copy_(modified_codebook)
766
+
767
+ def expire_codes_(self, batch_samples):
768
+ if self.threshold_ema_dead_code == 0:
769
+ return
770
+
771
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
772
+ if not torch.any(expired_codes):
773
+ return
774
+
775
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
776
+ self.replace_(batch_samples, mask=expired_codes)
777
+ flashy.distrib.broadcast_tensors(self.buffers())
778
+
779
+ def preprocess(self, x):
780
+ x = rearrange(x, "... d -> (...) d")
781
+ return x
782
+
783
+ def quantize(self, x):
784
+ embed = self.embed.t()
785
+ dist = -(
786
+ x.pow(2).sum(1, keepdim=True)
787
+ - 2 * x @ embed
788
+ + embed.pow(2).sum(0, keepdim=True)
789
+ )
790
+ embed_ind = dist.max(dim=-1).indices
791
+ return embed_ind
792
+
793
+ def postprocess_emb(self, embed_ind, shape):
794
+ return embed_ind.view(*shape[:-1])
795
+
796
+ def dequantize(self, embed_ind):
797
+ quantize = F.embedding(embed_ind, self.embed)
798
+ return quantize
799
+
800
+ def encode(self, x):
801
+ shape = x.shape
802
+ # pre-process
803
+ x = self.preprocess(x)
804
+ # quantize
805
+ embed_ind = self.quantize(x)
806
+ # post-process
807
+ embed_ind = self.postprocess_emb(embed_ind, shape)
808
+ return embed_ind
809
+
810
+ def decode(self, embed_ind):
811
+ quantize = self.dequantize(embed_ind)
812
+ return quantize
813
+
814
+ def forward(self, x):
815
+ raise NotImplementedError()
816
+ shape, dtype = x.shape, x.dtype
817
+ x = self.preprocess(x)
818
+ self.init_embed_(x)
819
+
820
+ embed_ind = self.quantize(x)
821
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
822
+ embed_ind = self.postprocess_emb(embed_ind, shape)
823
+ quantize = self.dequantize(embed_ind)
824
+
825
+ if self.training:
826
+ # We do the expiry of code at that point as buffers are in sync
827
+ # and all the workers will take the same decision.
828
+ self.expire_codes_(x)
829
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
830
+ embed_sum = x.t() @ embed_onehot
831
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
832
+ cluster_size = (
833
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
834
+ * self.cluster_size.sum()
835
+ )
836
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
837
+ self.embed.data.copy_(embed_normalized)
838
+
839
+ return quantize, embed_ind
840
+
841
+
842
+ class VectorQuantization(nn.Module):
843
+ """Vector quantization implementation.
844
+ Currently supports only euclidean distance.
845
+
846
+ Args:
847
+ dim (int): Dimension
848
+ codebook_size (int): Codebook size
849
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
850
+ decay (float): Decay for exponential moving average over the codebooks.
851
+ epsilon (float): Epsilon value for numerical stability.
852
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
853
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
854
+ threshold_ema_dead_code (int):
855
+ channels_last (bool): Channels are the last dimension in the input tensors.
856
+ commitment_weight (float): Weight for commitment loss.
857
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
858
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
859
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
860
+ for orthogonal regularization.
861
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
862
+ that have an exponential moving average cluster size less than the specified threshold with
863
+ randomly selected vector from the current batch.
864
+ """
865
+ def __init__(
866
+ self,
867
+ dim: int,
868
+ codebook_size: int,
869
+ codebook_dim: tp.Optional[int] = None,
870
+ decay: float = 0.8,
871
+ epsilon: float = 1e-5,
872
+ kmeans_init: bool = False,
873
+ kmeans_iters: int = 10,
874
+ threshold_ema_dead_code: int = 2,
875
+ channels_last: bool = False,
876
+ commitment_weight: float = 1.,
877
+ orthogonal_reg_weight: float = 0.0,
878
+ orthogonal_reg_active_codes_only: bool = False,
879
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
880
+ ):
881
+ super().__init__()
882
+ _codebook_dim: int = default(codebook_dim, dim)
883
+
884
+ requires_projection = _codebook_dim != dim
885
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
886
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
887
+
888
+ self.epsilon = epsilon
889
+ self.commitment_weight = commitment_weight
890
+
891
+ self.orthogonal_reg_weight = orthogonal_reg_weight
892
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
893
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
894
+
895
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
896
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
897
+ decay=decay, epsilon=epsilon,
898
+ threshold_ema_dead_code=threshold_ema_dead_code)
899
+ self.codebook_size = codebook_size
900
+
901
+ self.channels_last = channels_last
902
+
903
+ @property
904
+ def codebook(self):
905
+ return self._codebook.embed
906
+
907
+ @property
908
+ def inited(self):
909
+ return self._codebook.inited
910
+
911
+ def _preprocess(self, x):
912
+ if not self.channels_last:
913
+ x = rearrange(x, "b d n -> b n d")
914
+ return x
915
+
916
+ def _postprocess(self, quantize):
917
+ if not self.channels_last:
918
+ quantize = rearrange(quantize, "b n d -> b d n")
919
+ return quantize
920
+
921
+ def encode(self, x):
922
+ x = self._preprocess(x)
923
+ x = self.project_in(x)
924
+ embed_in = self._codebook.encode(x)
925
+ return embed_in
926
+
927
+ def decode(self, embed_ind):
928
+ quantize = self._codebook.decode(embed_ind)
929
+ quantize = self.project_out(quantize)
930
+ quantize = self._postprocess(quantize)
931
+ return quantize
932
+
933
+ def forward(self, x):
934
+ device = x.device
935
+ x = self._preprocess(x)
936
+
937
+ x = self.project_in(x)
938
+ quantize, embed_ind = self._codebook(x)
939
+
940
+ if self.training:
941
+ quantize = x + (quantize - x).detach()
942
+
943
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
944
+
945
+ if self.training:
946
+ if self.commitment_weight > 0:
947
+ commit_loss = F.mse_loss(quantize.detach(), x)
948
+ loss = loss + commit_loss * self.commitment_weight
949
+
950
+ if self.orthogonal_reg_weight > 0:
951
+ codebook = self.codebook
952
+
953
+ if self.orthogonal_reg_active_codes_only:
954
+ # only calculate orthogonal loss for the activated codes for this batch
955
+ unique_code_ids = torch.unique(embed_ind)
956
+ codebook = codebook[unique_code_ids]
957
+
958
+ num_codes = codebook.shape[0]
959
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
960
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
961
+ codebook = codebook[rand_ids]
962
+
963
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
964
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
965
+
966
+ quantize = self.project_out(quantize)
967
+ quantize = self._postprocess(quantize)
968
+
969
+ return quantize, embed_ind, loss
970
+
971
+
972
+ class ResidualVectorQuantization(nn.Module):
973
+ """Residual vector quantization implementation.
974
+
975
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
976
+ """
977
+ def __init__(self, *, num_quantizers, **kwargs):
978
+ super().__init__()
979
+ codebook_size = kwargs.pop('codebook_size', None)
980
+ if codebook_size is None:
981
+ raise ValueError("codebook_size must be provided in kwargs")
982
+ if type(codebook_size) != list:
983
+ codebook_size = [codebook_size] * num_quantizers
984
+ self.layers = nn.ModuleList(
985
+ [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)]
986
+ )
987
+
988
+
989
+ # self.layers = nn.ModuleList(
990
+ # [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
991
+ # )
992
+
993
+ def forward(self, x, n_q: tp.Optional[int] = None):
994
+ quantized_out = 0.0
995
+ residual = x
996
+
997
+ all_losses = []
998
+ all_indices = []
999
+
1000
+ n_q = n_q or len(self.layers)
1001
+
1002
+ for i, layer in enumerate(self.layers[:n_q]):
1003
+ quantized, indices, loss = layer(residual)
1004
+ residual = residual - quantized
1005
+ quantized_out = quantized_out + quantized
1006
+ all_indices.append(indices)
1007
+ all_losses.append(loss)
1008
+
1009
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
1010
+ return quantized_out, out_indices, out_losses
1011
+
1012
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
1013
+ residual = x
1014
+ all_indices = []
1015
+ n_q = n_q or len(self.layers)
1016
+ for layer in self.layers[:n_q]:
1017
+ indices = layer.encode(residual)
1018
+ quantized = layer.decode(indices)
1019
+ # the original code is below
1020
+ # since quantize has the gradient of residual, according to line 321
1021
+ # quantize = x + (quantize - x).detach()
1022
+ # the code below will make commitment loss to be 0 for all codebooks except for codebook1
1023
+ # https://github.com/facebookresearch/encodec/issues/25
1024
+ # therefore we change it
1025
+
1026
+ residual = residual - quantized
1027
+ # residual = residual - quantized.detach()
1028
+ # since commitment loss is averaged, the scale of the loss won't get change (not as said in the issue above)
1029
+ all_indices.append(indices)
1030
+ out_indices = torch.stack(all_indices)
1031
+ return out_indices
1032
+
1033
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
1034
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
1035
+ for i, indices in enumerate(q_indices):
1036
+ layer = self.layers[i]
1037
+ quantized = layer.decode(indices)
1038
+ quantized_out = quantized_out + quantized
1039
+ return quantized_out
1040
+
1041
+
1042
+ class ResidualVectorQuantizer(BaseQuantizer):
1043
+ """Residual Vector Quantizer.
1044
+
1045
+ Args:
1046
+ dimension (int): Dimension of the codebooks.
1047
+ n_q (int): Number of residual vector quantizers used.
1048
+ q_dropout (bool): Random quantizer drop out at train time.
1049
+ bins (int): Codebook size.
1050
+ decay (float): Decay for exponential moving average over the codebooks.
1051
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
1052
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
1053
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
1054
+ that have an exponential moving average cluster size less than the specified threshold with
1055
+ randomly selected vector from the current batch.
1056
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
1057
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
1058
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
1059
+ for orthogonal regularization.
1060
+ """
1061
+ def __init__(
1062
+ self,
1063
+ dimension: int = 256,
1064
+ n_q: int = 8,
1065
+ q_dropout: bool = False,
1066
+ bins: tp.Union[int, tp.List[int]] = 1024,
1067
+ decay: float = 0.99,
1068
+ kmeans_init: bool = True,
1069
+ kmeans_iters: int = 10,
1070
+ threshold_ema_dead_code: int = 2,
1071
+ orthogonal_reg_weight: float = 0.0,
1072
+ orthogonal_reg_active_codes_only: bool = False,
1073
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
1074
+ ):
1075
+ super().__init__()
1076
+ self.max_n_q = n_q
1077
+ self.n_q = n_q
1078
+ self.q_dropout = q_dropout
1079
+ self.dimension = dimension
1080
+ self.bins = bins
1081
+ self.decay = decay
1082
+ self.kmeans_init = kmeans_init
1083
+ self.kmeans_iters = kmeans_iters
1084
+ self.threshold_ema_dead_code = threshold_ema_dead_code
1085
+ self.orthogonal_reg_weight = orthogonal_reg_weight
1086
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
1087
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
1088
+ self.vq = ResidualVectorQuantization(
1089
+ dim=self.dimension,
1090
+ codebook_size=self.bins,
1091
+ num_quantizers=self.n_q,
1092
+ decay=self.decay,
1093
+ kmeans_init=self.kmeans_init,
1094
+ kmeans_iters=self.kmeans_iters,
1095
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
1096
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
1097
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
1098
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
1099
+ channels_last=False
1100
+ )
1101
+
1102
+ def forward(self, x: torch.Tensor, frame_rate: int):
1103
+ n_q = self.n_q
1104
+ if self.training and self.q_dropout:
1105
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
1106
+ if type(self.bins) == list:
1107
+ bins = self.bins
1108
+ else:
1109
+ bins = [self.bins] * self.n_q
1110
+ bw_per_q = [math.log2(bin) * frame_rate / 1000 for bin in bins]
1111
+ bw = torch.tensor(sum(bw_per_q)).to(x)
1112
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
1113
+ codes = codes.transpose(0, 1)
1114
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1115
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
1116
+
1117
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1118
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
1119
+ The RVQ encode method sets the appropriate number of quantizers to use
1120
+ and returns indices for each quantizer.
1121
+ """
1122
+ n_q = self.n_q
1123
+ codes = self.vq.encode(x, n_q=n_q)
1124
+ codes = codes.transpose(0, 1)
1125
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1126
+ return codes
1127
+
1128
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1129
+ """Decode the given codes to the quantized representation."""
1130
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
1131
+ codes = codes.transpose(0, 1)
1132
+ quantized = self.vq.decode(codes)
1133
+ return quantized
1134
+
1135
+ @property
1136
+ def total_codebooks(self):
1137
+ return self.max_n_q
1138
+
1139
+ @property
1140
+ def num_codebooks(self):
1141
+ return self.n_q
1142
+
1143
+ def set_num_codebooks(self, n: int):
1144
+ assert n > 0 and n <= self.max_n_q
1145
+ self.n_q = n
1146
+
1147
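# [Editor's note: illustrative sketch, not part of the uploaded file.]
# ResidualVectorQuantizer.forward above reports bandwidth as
#     sum_k log2(bins_k) * frame_rate / 1000   (in kbps),
# i.e. bits per frame per codebook times frames per second. With hypothetical
# values (4 codebooks of 2048 entries at a 50 Hz frame rate, not taken from any
# shipped checkpoint):
import math
bins, n_q, frame_rate = 2048, 4, 50
bw_per_q = math.log2(bins) * frame_rate / 1000   # 11 bits/frame -> 0.55 kbps per codebook
print(n_q * bw_per_q)                            # 2.2 kbps, the value reported in QuantizedResult.bandwidth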
+ class DummyQuantizer(BaseQuantizer):
1148
+ """Fake quantizer that actually does not perform any quantization.
1149
+ """
1150
+ def __init__(self):
1151
+ super().__init__()
1152
+
1153
+ def forward(self, x: torch.Tensor, frame_rate: int):
1154
+ q = x.unsqueeze(1)
1155
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
1156
+
1157
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1158
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
1159
+ In the case of the DummyQuantizer, the codes are actually identical
1160
+ to the input and resulting quantized representation as no quantization is done.
1161
+ """
1162
+ return x.unsqueeze(1)
1163
+
1164
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1165
+ """Decode the given codes to the quantized representation.
1166
+ In the case of the DummyQuantizer, the codes are actually identical
1167
+ to the input and resulting quantized representation as no quantization is done.
1168
+ """
1169
+ return codes.squeeze(1)
1170
+
1171
+ @property
1172
+ def total_codebooks(self):
1173
+ """Total number of codebooks."""
1174
+ return 1
1175
+
1176
+ @property
1177
+ def num_codebooks(self):
1178
+ """Total number of codebooks."""
1179
+ return self.total_codebooks
1180
+
1181
+ def set_num_codebooks(self, n: int):
1182
+ """Set the number of active codebooks."""
1183
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
1184
+
1185
+
1186
+ class EncodecModel(CompressionModel):
1187
+ """Encodec model operating on the raw waveform.
1188
+
1189
+ Args:
1190
+ encoder (nn.Module): Encoder network.
1191
+ decoder (nn.Module): Decoder network.
1192
+ quantizer (BaseQuantizer): Quantizer network.
1193
+ frame_rate (int): Frame rate for the latent representation.
1194
+ sample_rate (int): Audio sample rate.
1195
+ channels (int): Number of audio channels.
1196
+ causal (bool): Whether to use a causal version of the model.
1197
+ renormalize (bool): Whether to renormalize the audio before running the model.
1198
+ """
1199
+ # we need assignment to override the property in the abstract class,
1200
+ # I couldn't find a better way...
1201
+ frame_rate: float = 0
1202
+ sample_rate: int = 0
1203
+ channels: int = 0
1204
+
1205
+ def __init__(self,
1206
+ encoder: nn.Module,
1207
+ decoder: nn.Module,
1208
+ quantizer: BaseQuantizer,
1209
+ frame_rate: int,
1210
+ sample_rate: int,
1211
+ channels: int,
1212
+ causal: bool = False,
1213
+ renormalize: bool = False):
1214
+ super().__init__()
1215
+ self.encoder = encoder
1216
+ self.decoder = decoder
1217
+ self.quantizer = quantizer
1218
+ self.frame_rate = frame_rate
1219
+ self.sample_rate = sample_rate
1220
+ self.channels = channels
1221
+ self.renormalize = renormalize
1222
+ self.causal = causal
1223
+ if self.causal:
1224
+ # we force disabling here to avoid handling linear overlap of segments
1225
+ # as supported in original EnCodec codebase.
1226
+ assert not self.renormalize, 'Causal model does not support renormalize'
1227
+
1228
+ @property
1229
+ def total_codebooks(self):
1230
+ """Total number of quantizer codebooks available."""
1231
+ return self.quantizer.total_codebooks
1232
+
1233
+ @property
1234
+ def num_codebooks(self):
1235
+ """Active number of codebooks used by the quantizer."""
1236
+ return self.quantizer.num_codebooks
1237
+
1238
+ def set_num_codebooks(self, n: int):
1239
+ """Set the active number of codebooks used by the quantizer."""
1240
+ self.quantizer.set_num_codebooks(n)
1241
+
1242
+ @property
1243
+ def cardinality(self):
1244
+ """Cardinality of each codebook."""
1245
+ return self.quantizer.bins
1246
+
1247
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1248
+ scale: tp.Optional[torch.Tensor]
1249
+ if self.renormalize:
1250
+ mono = x.mean(dim=1, keepdim=True)
1251
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1252
+ scale = 1e-8 + volume
1253
+ x = x / scale
1254
+ scale = scale.view(-1, 1)
1255
+ else:
1256
+ scale = None
1257
+ return x, scale
1258
+
1259
+ def postprocess(self,
1260
+ x: torch.Tensor,
1261
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1262
+ if scale is not None:
1263
+ assert self.renormalize
1264
+ x = x * scale.view(-1, 1, 1)
1265
+ return x
1266
+
1267
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1268
+ if encode:
1269
+ return self.encode(x)
1270
+ else:
1271
+ raise NotImplementedError("model forward and training is not supported.")
1272
+ assert x.dim() == 3
1273
+ length = x.shape[-1]
1274
+ x, scale = self.preprocess(x)
1275
+
1276
+ emb = self.encoder(x)
1277
+ q_res = self.quantizer(emb, self.frame_rate)
1278
+ out = self.decoder(q_res.x)
1279
+
1280
+ # remove extra padding added by the encoder and decoder
1281
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1282
+ out = out[..., :length]
1283
+
1284
+ q_res.x = self.postprocess(out, scale)
1285
+
1286
+ return q_res
1287
+
1288
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1289
+ """Encode the given input tensor to quantized representation along with scale parameter.
1290
+
1291
+ Args:
1292
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1293
+
1294
+ Returns:
1295
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1296
+ codes, an int tensor of shape [B, K, T], with K the number of codebooks used and T the number of frames.
1297
+ scale, a float tensor containing the scale for audio renormalization.
1298
+ """
1299
+ assert x.dim() == 3
1300
+ x, scale = self.preprocess(x)
1301
+ emb = self.encoder(x)
1302
+ codes = self.quantizer.encode(emb)
1303
+ return codes, scale
1304
+
1305
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1306
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1307
+ audio denormalization if needed.
1308
+
1309
+ Args:
1310
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1311
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1312
+
1313
+ Returns:
1314
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1315
+ """
1316
+ emb = self.decode_latent(codes)
1317
+ out = self.decoder(emb)
1318
+ out = self.postprocess(out, scale)
1319
+ # out contains extra padding added by the encoder and decoder
1320
+ return out
1321
+
1322
+ def decode_latent(self, codes: torch.Tensor):
1323
+ """Decode from the discrete codes to continuous latent space."""
1324
+ return self.quantizer.decode(codes)
1325
+
1326
+ class EncodecModel_encode_only(CompressionModel):
1327
+ """Encodec model operating on the raw waveform. Encode only, so no decoder
1328
+
1329
+ Args:
1330
+ encoder (nn.Module): Encoder network.
1331
+ quantizer (BaseQuantizer): Quantizer network.
1332
+ frame_rate (int): Frame rate for the latent representation.
1333
+ sample_rate (int): Audio sample rate.
1334
+ channels (int): Number of audio channels.
1335
+ causal (bool): Whether to use a causal version of the model.
1336
+ renormalize (bool): Whether to renormalize the audio before running the model.
1337
+ """
1338
+ # we need assignment to override the property in the abstract class,
1339
+ # I couldn't find a better way...
1340
+ frame_rate: float = 0
1341
+ sample_rate: int = 0
1342
+ channels: int = 0
1343
+
1344
+ def __init__(self,
1345
+ encoder: nn.Module,
1346
+ quantizer: BaseQuantizer,
1347
+ frame_rate: int,
1348
+ sample_rate: int,
1349
+ channels: int,
1350
+ causal: bool = False,
1351
+ renormalize: bool = False):
1352
+ super().__init__()
1353
+ self.encoder = encoder
1354
+ self.quantizer = quantizer
1355
+ self.frame_rate = frame_rate
1356
+ self.sample_rate = sample_rate
1357
+ self.channels = channels
1358
+ self.renormalize = renormalize
1359
+ self.causal = causal
1360
+ if self.causal:
1361
+ # we force disabling here to avoid handling linear overlap of segments
1362
+ # as supported in original EnCodec codebase.
1363
+ assert not self.renormalize, 'Causal model does not support renormalize'
1364
+
1365
+ @property
1366
+ def total_codebooks(self):
1367
+ """Total number of quantizer codebooks available."""
1368
+ return self.quantizer.total_codebooks
1369
+
1370
+ @property
1371
+ def num_codebooks(self):
1372
+ """Active number of codebooks used by the quantizer."""
1373
+ return self.quantizer.num_codebooks
1374
+
1375
+ def set_num_codebooks(self, n: int):
1376
+ """Set the active number of codebooks used by the quantizer."""
1377
+ self.quantizer.set_num_codebooks(n)
1378
+
1379
+ @property
1380
+ def cardinality(self):
1381
+ """Cardinality of each codebook."""
1382
+ return self.quantizer.bins
1383
+
1384
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1385
+ scale: tp.Optional[torch.Tensor]
1386
+ if self.renormalize:
1387
+ mono = x.mean(dim=1, keepdim=True)
1388
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1389
+ scale = 1e-8 + volume
1390
+ x = x / scale
1391
+ scale = scale.view(-1, 1)
1392
+ else:
1393
+ scale = None
1394
+ return x, scale
1395
+
1396
+ def postprocess(self,
1397
+ x: torch.Tensor,
1398
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1399
+ if scale is not None:
1400
+ assert self.renormalize
1401
+ x = x * scale.view(-1, 1, 1)
1402
+ return x
1403
+
1404
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1405
+ if encode:
1406
+ return self.encode(x)
1407
+ else:
1408
+ raise NotImplementedError("model forward and training is not supported.")
1409
+ assert x.dim() == 3
1410
+ length = x.shape[-1]
1411
+ x, scale = self.preprocess(x)
1412
+
1413
+ emb = self.encoder(x)
1414
+ q_res = self.quantizer(emb, self.frame_rate)
1415
+ out = self.decoder(q_res.x)
1416
+
1417
+ # remove extra padding added by the encoder and decoder
1418
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1419
+ out = out[..., :length]
1420
+
1421
+ q_res.x = self.postprocess(out, scale)
1422
+
1423
+ return q_res
1424
+
1425
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1426
+ """Encode the given input tensor to quantized representation along with scale parameter.
1427
+
1428
+ Args:
1429
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1430
+
1431
+ Returns:
1432
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1433
+ codes, an int tensor of shape [B, K, T], with K the number of codebooks used and T the number of frames.
1434
+ scale, a float tensor containing the scale for audio renormalization.
1435
+ """
1436
+ assert x.dim() == 3
1437
+ x, scale = self.preprocess(x)
1438
+ emb = self.encoder(x)
1439
+ codes = self.quantizer.encode(emb)
1440
+ return codes, scale
1441
+
1442
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1443
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1444
+ audio denormalization if needed.
1445
+
1446
+ Args:
1447
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1448
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1449
+
1450
+ Returns:
1451
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1452
+ """
1453
+ raise NotImplementedError("Decode is not supported for encode only model")
1454
+ emb = self.decode_latent(codes)
1455
+ out = self.decoder(emb)
1456
+ out = self.postprocess(out, scale)
1457
+ # out contains extra padding added by the encoder and decoder
1458
+ return out
1459
+
1460
+ def decode_latent(self, codes: torch.Tensor):
1461
+ """Decode from the discrete codes to continuous latent space."""
1462
+ raise NotImplementedError("Decode is not supported for encode only model")
1463
+ return self.quantizer.decode(codes)
1464
+
1465
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
1466
+ klass = {
1467
+ 'no_quant': DummyQuantizer,
1468
+ 'rvq': ResidualVectorQuantizer
1469
+ }[quantizer]
1470
+ kwargs = dict_from_config(getattr(cfg, quantizer))
1471
+ if quantizer != 'no_quant':
1472
+ kwargs['dimension'] = dimension
1473
+ return klass(**kwargs)
1474
+
1475
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
1476
+ if encoder_name == 'seanet':
1477
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
1478
+ encoder_override_kwargs = kwargs.pop('encoder')
1479
+ decoder_override_kwargs = kwargs.pop('decoder')
1480
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
1481
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
1482
+ encoder = SEANetEncoder(**encoder_kwargs)
1483
+ decoder = SEANetDecoder(**decoder_kwargs)
1484
+ return encoder, decoder
1485
+ else:
1486
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
1487
+
1488
+
1489
+ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
1490
+ """Instantiate a compression model."""
1491
+ if device is None:
1492
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1493
+ state = torch.load(ckpt_fn, map_location='cpu')
1494
+ cfg = state['xp.cfg']
1495
+ cfg.device = str(device)
1496
+ weights = state['best_state']['model']
1497
+ assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
1498
+ if encode_only:
1499
+ all_keys = list(weights.keys())
1500
+ for key in all_keys:
1501
+ if key.startswith('decoder'):
1502
+ del weights[key]
1503
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1504
+ encoder_name = kwargs.pop('autoencoder')
1505
+ quantizer_name = kwargs.pop('quantizer')
1506
+ encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
1507
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1508
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1509
+ renormalize = kwargs.pop('renormalize', False)
1510
+ # deprecated params
1511
+ kwargs.pop('renorm', None)
1512
+ compression_model = EncodecModel_encode_only(encoder, quantizer,
1513
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1514
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1515
+ compression_model.load_state_dict(weights)
1516
+ compression_model.eval()
1517
+ return compression_model
1518
+
1519
+ else:
1520
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1521
+ encoder_name = kwargs.pop('autoencoder')
1522
+ quantizer_name = kwargs.pop('quantizer')
1523
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
1524
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1525
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1526
+ renormalize = kwargs.pop('renormalize', False)
1527
+ # deprecated params
1528
+ kwargs.pop('renorm', None)
1529
+ compression_model = EncodecModel(encoder, decoder, quantizer,
1530
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1531
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1532
+ compression_model.load_state_dict(weights)
1533
+ compression_model.eval()
1534
+ return compression_model
1535
+
1536
+ if __name__ == "__main__":
1537
+ import torchaudio
1538
+ ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
1539
+ audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
1540
+ audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
1541
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1542
+ model = get_compression_model(ckpt_fn, device=device)
1543
+
1544
+ for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
1545
+ audio_in, sr = torchaudio.load(audio_in_fn)
1546
+ if sr != model.sample_rate:
1547
+ audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
1548
+ if audio_in.shape[0] == 2:
1549
+ audio_in = audio_in.mean(dim=0, keepdim=True)
1550
+ audio_in = audio_in.unsqueeze(0)
1551
+ audio_in = audio_in.to(torch.float32).to(device)
1552
+ codes = model.encode(audio_in)[0]
1553
+ audio_out = model.decode(codes)[0].cpu()
1554
+ torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
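A minimal usage sketch for the encode-only path defined above (editor's addition, not part of the uploaded files); the checkpoint path, audio path, and import location are placeholders to adapt to your checkout:
import torch
import torchaudio
from encodec import get_compression_model  # hypothetical import; point it at this module

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_compression_model("pretrained/encodec_6f79c6a8.th", encode_only=True, device=device)

wav, sr = torchaudio.load("example_utterance.wav")  # placeholder input file
if sr != model.sample_rate:
    wav = torchaudio.transforms.Resample(sr, model.sample_rate)(wav)
wav = wav.mean(dim=0, keepdim=True).unsqueeze(0).to(device)  # mono, shape [1, 1, T]

with torch.no_grad():
    codes, _ = model.encode(wav)  # integer codes of shape [1, K, T_frames]
torch.save(codes.squeeze(0).cpu(), "example_utterance_codes.pt")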
data/ll60k_preprocessing/config.yaml ADDED
@@ -0,0 +1,75 @@
1
+ # WARNING: This is the base configuration file shared across ALL solvers in AudioCraft
2
+ # Please don't update this file directly. Instead use distinct configuration files
3
+ # to override the below configuration.
4
+ defaults:
5
+ - _self_
6
+ - dset: default
7
+ - solver: default
8
+
9
+ device: cuda
10
+ dtype: float32
11
+ autocast: false
12
+ autocast_dtype: bfloat16
13
+ seed: 2036
14
+ show: false # just show the model and its size and exit
15
+ continue_from: # continue from a given sig or path
16
+ execute_only: # can be set to generate/evaluate/valid to run that stage
17
+ execute_inplace: false # don't enforce continue_from to be set
18
+ # to enable inplace execution of the stage. This assume
19
+ # that you know what you are doing and execute stage
20
+ # preserving the original xp sig.
21
+ benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them
22
+
23
+ efficient_attention_backend: torch # can be torch or xformers.
24
+ num_threads: 1 # called with torch.set_num_threads.
25
+ mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server).
26
+
27
+
28
+ label: # use this if you want twice the same exp, with a name.
29
+
30
+ # logging parameters
31
+ logging:
32
+ level: INFO
33
+ log_updates: 10
34
+ log_tensorboard: false
35
+ log_wandb: false
36
+ tensorboard:
37
+ with_media_logging: false
38
+ name: # optional name for the experiment
39
+ sub_dir: # optional sub directory to store tensorboard data
40
+ wandb:
41
+ with_media_logging: true
42
+ project: # project name
43
+ name: # optional name for the experiment
44
+ group: # optional group
45
+
46
+ # SLURM launcher configuration.
47
+ slurm:
48
+ gpus: 4 # convenience parameter, number of GPUs to use.
49
+ mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`.
50
+ time: 3600
51
+ constraint:
52
+ partition:
53
+ comment:
54
+ setup: []
55
+ exclude: ''
56
+
57
+ # dora parameters
58
+ dora:
59
+ # Output folder for all artifacts of an experiment.
60
+ dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
61
+ # The following entries will be ignored by dora when computing the unique XP signature.
62
+ # Note that slurm.* and dora.* are automatically ignored.
63
+ exclude: [
64
+ 'device', 'wandb.*', 'tensorboard.*', 'logging.*',
65
+ 'dataset.num_workers', 'eval.num_workers', 'special.*',
66
+ 'metrics.visqol.bin', 'metrics.fad.bin',
67
+ 'execute_only', 'execute_best', 'generate.every',
68
+ 'optim.eager_sync', 'profiler.*', 'deadlock.*',
69
+ 'efficient_attention_backend', 'num_threads', 'mp_start_method',
70
+ ]
71
+ use_rendezvous: false
72
+ # for grids, always run from a clean repo, allowing reliable runs and storing
73
+ # the exact commit. Your repo must be absolutely pristine clean.
74
+ # Local `dora run` are not impacted for easier debugging.
75
+ git_save: true
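To sanity-check this base config outside of dora/Hydra (which normally composes the `defaults:` list), it can be loaded directly with omegaconf; a small sketch (editor's addition), with the path assumed relative to the repo root and omegaconf >= 2.1 for the `oc.env` resolver:
from omegaconf import OmegaConf

cfg = OmegaConf.load("data/ll60k_preprocessing/config.yaml")
print(cfg.logging.level)  # INFO
print(cfg.slurm.gpus)     # 4
# Interpolations such as ${oc.env:USER} in dora.dir resolve against the environment:
print(OmegaConf.to_container(cfg, resolve=True)["dora"]["dir"])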
data/ll60k_preprocessing/encodec.py ADDED
@@ -0,0 +1,1554 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from dataclasses import dataclass, field
12
+ import logging
13
+ import math
14
+ from pathlib import Path
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ from torch import einsum
21
+ import torch.nn.functional as F
22
+ from torch.nn.utils import spectral_norm, weight_norm
23
+
24
+ import logging
25
+ import warnings
26
+ from einops import rearrange, repeat
27
+ import omegaconf
28
+ # import flashy
29
+
30
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
31
+ 'time_group_norm'])
32
+
33
+ def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
34
+ """Convenience function to map an omegaconf configuration to a dictionary.
35
+
36
+ Args:
37
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
38
+ Returns:
39
+ dict: Config as dictionary object.
40
+ """
41
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
42
+ assert isinstance(dct, dict)
43
+ return dct
44
+
45
+ @dataclass
46
+ class QuantizedResult:
47
+ x: torch.Tensor
48
+ codes: torch.Tensor
49
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
50
+ penalty: tp.Optional[torch.Tensor] = None
51
+ metrics: dict = field(default_factory=dict)
52
+
53
+ class BaseQuantizer(nn.Module):
54
+ """Base class for quantizers.
55
+ """
56
+
57
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
58
+ """
59
+ Given input tensor x, returns first the quantized (or approximately quantized)
60
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
61
+ Finally, this returns a dict of metrics to update logging etc.
62
+ Frame rate must be passed so that the bandwidth is properly computed.
63
+ """
64
+ raise NotImplementedError()
65
+
66
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
67
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
68
+ raise NotImplementedError()
69
+
70
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
71
+ """Decode the given codes to the quantized representation."""
72
+ raise NotImplementedError()
73
+
74
+ @property
75
+ def total_codebooks(self):
76
+ """Total number of codebooks."""
77
+ raise NotImplementedError()
78
+
79
+ @property
80
+ def num_codebooks(self):
81
+ """Number of active codebooks."""
82
+ raise NotImplementedError()
83
+
84
+ def set_num_codebooks(self, n: int):
85
+ """Set the number of active codebooks."""
86
+ raise NotImplementedError()
87
+
88
+ class CompressionModel(ABC, nn.Module):
89
+ """Base API for all compression model that aim at being used as audio tokenizers
90
+ with a language model.
91
+ """
92
+
93
+ @abstractmethod
94
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
95
+ ...
96
+
97
+ @abstractmethod
98
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
99
+ """See `EncodecModel.encode`."""
100
+ ...
101
+
102
+ @abstractmethod
103
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
104
+ """See `EncodecModel.decode`."""
105
+ ...
106
+
107
+ @abstractmethod
108
+ def decode_latent(self, codes: torch.Tensor):
109
+ """Decode from the discrete codes to continuous latent space."""
110
+ ...
111
+
112
+ @property
113
+ @abstractmethod
114
+ def channels(self) -> int:
115
+ ...
116
+
117
+ @property
118
+ @abstractmethod
119
+ def frame_rate(self) -> float:
120
+ ...
121
+
122
+ @property
123
+ @abstractmethod
124
+ def sample_rate(self) -> int:
125
+ ...
126
+
127
+ @property
128
+ @abstractmethod
129
+ def cardinality(self) -> int:
130
+ ...
131
+
132
+ @property
133
+ @abstractmethod
134
+ def num_codebooks(self) -> int:
135
+ ...
136
+
137
+ @property
138
+ @abstractmethod
139
+ def total_codebooks(self) -> int:
140
+ ...
141
+
142
+ @abstractmethod
143
+ def set_num_codebooks(self, n: int):
144
+ """Set the active number of codebooks used by the quantizer."""
145
+ ...
146
+
147
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
148
+ assert norm in CONV_NORMALIZATIONS
149
+ if norm == 'weight_norm':
150
+ return weight_norm(module)
151
+ elif norm == 'spectral_norm':
152
+ return spectral_norm(module)
153
+ else:
154
+ # We already check was in CONV_NORMALIZATION, so any other choice
155
+ # doesn't need reparametrization.
156
+ return module
157
+
158
+
159
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
160
+ """Return the proper normalization module. If causal is True, this will ensure the returned
161
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
162
+ """
163
+ assert norm in CONV_NORMALIZATIONS
164
+ if norm == 'time_group_norm':
165
+ if causal:
166
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
167
+ assert isinstance(module, nn.modules.conv._ConvNd)
168
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
169
+ else:
170
+ return nn.Identity()
171
+
172
+
173
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
174
+ padding_total: int = 0) -> int:
175
+ """See `pad_for_conv1d`."""
176
+ length = x.shape[-1]
177
+ n_frames = (length - kernel_size + padding_total) / stride + 1
178
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
179
+ return ideal_length - length
180
+
181
+
182
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
183
+ """Pad for a convolution to make sure that the last window is full.
184
+ Extra padding is added at the end. This is required to ensure that we can rebuild
185
+ an output of the same length, as otherwise, even with padding, some time steps
186
+ might get removed.
187
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
188
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
189
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
190
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
191
+ 1 2 3 4 # once you removed padding, we are missing one time step !
192
+ """
193
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
194
+ return F.pad(x, (0, extra_padding))
195
+
196
+
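# [Editor's note: worked example, not part of the uploaded file.]
# get_extra_padding_for_conv1d above, with the same numbers as the pad_for_conv1d
# docstring (length 5 signal, kernel_size=4, stride=2, padding_total=4):
#   n_frames     = (5 - 4 + 4) / 2 + 1 = 3.5
#   ideal_length = (ceil(3.5) - 1) * 2 + (4 - 4) = 6
#   extra        = 6 - 5 = 1
# so one extra zero is appended and the last window is full (4 output frames).
# Quick check (assumes this module's function is in scope):
#   get_extra_padding_for_conv1d(torch.zeros(1, 1, 5), kernel_size=4, stride=2, padding_total=4)  # -> 1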
197
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
198
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
199
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
200
+ """
201
+ length = x.shape[-1]
202
+ padding_left, padding_right = paddings
203
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
204
+ if mode == 'reflect':
205
+ max_pad = max(padding_left, padding_right)
206
+ extra_pad = 0
207
+ if length <= max_pad:
208
+ extra_pad = max_pad - length + 1
209
+ x = F.pad(x, (0, extra_pad))
210
+ padded = F.pad(x, paddings, mode, value)
211
+ end = padded.shape[-1] - extra_pad
212
+ return padded[..., :end]
213
+ else:
214
+ return F.pad(x, paddings, mode, value)
215
+
216
+
217
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
218
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
219
+ padding_left, padding_right = paddings
220
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
221
+ assert (padding_left + padding_right) <= x.shape[-1]
222
+ end = x.shape[-1] - padding_right
223
+ return x[..., padding_left: end]
224
+
225
+
226
+ class NormConv1d(nn.Module):
227
+ """Wrapper around Conv1d and normalization applied to this conv
228
+ to provide a uniform interface across normalization approaches.
229
+ """
230
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
231
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
232
+ super().__init__()
233
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
234
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
235
+ self.norm_type = norm
236
+
237
+ def forward(self, x):
238
+ x = self.conv(x)
239
+ x = self.norm(x)
240
+ return x
241
+
242
+
243
+ class NormConv2d(nn.Module):
244
+ """Wrapper around Conv2d and normalization applied to this conv
245
+ to provide a uniform interface across normalization approaches.
246
+ """
247
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
248
+ super().__init__()
249
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
250
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
251
+ self.norm_type = norm
252
+
253
+ def forward(self, x):
254
+ x = self.conv(x)
255
+ x = self.norm(x)
256
+ return x
257
+
258
+
259
+ class NormConvTranspose1d(nn.Module):
260
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
261
+ to provide a uniform interface across normalization approaches.
262
+ """
263
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
264
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
265
+ super().__init__()
266
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
267
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
268
+ self.norm_type = norm
269
+
270
+ def forward(self, x):
271
+ x = self.convtr(x)
272
+ x = self.norm(x)
273
+ return x
274
+
275
+
276
+ class NormConvTranspose2d(nn.Module):
277
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
278
+ to provide a uniform interface across normalization approaches.
279
+ """
280
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
281
+ super().__init__()
282
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
283
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
284
+
285
+ def forward(self, x):
286
+ x = self.convtr(x)
287
+ x = self.norm(x)
288
+ return x
289
+
290
+
291
+ class StreamableConv1d(nn.Module):
292
+ """Conv1d with some builtin handling of asymmetric or causal padding
293
+ and normalization.
294
+ """
295
+ def __init__(self, in_channels: int, out_channels: int,
296
+ kernel_size: int, stride: int = 1, dilation: int = 1,
297
+ groups: int = 1, bias: bool = True, causal: bool = False,
298
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
299
+ pad_mode: str = 'reflect'):
300
+ super().__init__()
301
+ # warn user on unusual setup between dilation and stride
302
+ if stride > 1 and dilation > 1:
303
+ warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
304
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
305
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
306
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
307
+ norm=norm, norm_kwargs=norm_kwargs)
308
+ self.causal = causal
309
+ self.pad_mode = pad_mode
310
+
311
+ def forward(self, x):
312
+ B, C, T = x.shape
313
+ kernel_size = self.conv.conv.kernel_size[0]
314
+ stride = self.conv.conv.stride[0]
315
+ dilation = self.conv.conv.dilation[0]
316
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
317
+ padding_total = kernel_size - stride
318
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
319
+ if self.causal:
320
+ # Left padding for causal
321
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
322
+ else:
323
+ # Asymmetric padding required for odd strides
324
+ padding_right = padding_total // 2
325
+ padding_left = padding_total - padding_right
326
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
327
+ return self.conv(x)
328
+
329
+
330
+ class StreamableConvTranspose1d(nn.Module):
331
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
332
+ and normalization.
333
+ """
334
+ def __init__(self, in_channels: int, out_channels: int,
335
+ kernel_size: int, stride: int = 1, causal: bool = False,
336
+ norm: str = 'none', trim_right_ratio: float = 1.,
337
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
338
+ super().__init__()
339
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
340
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
341
+ self.causal = causal
342
+ self.trim_right_ratio = trim_right_ratio
343
+ assert self.causal or self.trim_right_ratio == 1., \
344
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
345
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
346
+
347
+ def forward(self, x):
348
+ kernel_size = self.convtr.convtr.kernel_size[0]
349
+ stride = self.convtr.convtr.stride[0]
350
+ padding_total = kernel_size - stride
351
+
352
+ y = self.convtr(x)
353
+
354
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
355
+ # removed at the very end, when keeping only the right length for the output,
356
+ # as removing it here would require also passing the length at the matching layer
357
+ # in the encoder.
358
+ if self.causal:
359
+ # Trim the padding on the right according to the specified ratio
360
+ # if trim_right_ratio = 1.0, trim everything from right
361
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
362
+ padding_left = padding_total - padding_right
363
+ y = unpad1d(y, (padding_left, padding_right))
364
+ else:
365
+ # Asymmetric padding required for odd strides
366
+ padding_right = padding_total // 2
367
+ padding_left = padding_total - padding_right
368
+ y = unpad1d(y, (padding_left, padding_right))
369
+ return y
370
+
371
+
372
+ class StreamableLSTM(nn.Module):
373
+ """LSTM without worrying about the hidden state, nor the layout of the data.
374
+ Expects input as convolutional layout.
375
+ """
376
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
377
+ super().__init__()
378
+ self.skip = skip
379
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
380
+
381
+ def forward(self, x):
382
+ x = x.permute(2, 0, 1)
383
+ y, _ = self.lstm(x)
384
+ if self.skip:
385
+ y = y + x
386
+ y = y.permute(1, 2, 0)
387
+ return y
388
+
389
+
390
+ class SEANetResnetBlock(nn.Module):
391
+ """Residual block from SEANet model.
392
+
393
+ Args:
394
+ dim (int): Dimension of the input/output.
395
+ kernel_sizes (list): List of kernel sizes for the convolutions.
396
+ dilations (list): List of dilations for the convolutions.
397
+ activation (str): Activation function.
398
+ activation_params (dict): Parameters to provide to the activation function.
399
+ norm (str): Normalization method.
400
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
401
+ causal (bool): Whether to use fully causal convolution.
402
+ pad_mode (str): Padding mode for the convolutions.
403
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
404
+ true_skip (bool): Whether to use true skip connection or a simple
405
+ (streamable) convolution as the skip connection.
406
+ """
407
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
408
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
409
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
410
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
411
+ super().__init__()
412
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
413
+ act = getattr(nn, activation)
414
+ hidden = dim // compress
415
+ block = []
416
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
417
+ in_chs = dim if i == 0 else hidden
418
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
419
+ block += [
420
+ act(**activation_params),
421
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
422
+ norm=norm, norm_kwargs=norm_params,
423
+ causal=causal, pad_mode=pad_mode),
424
+ ]
425
+ self.block = nn.Sequential(*block)
426
+ self.shortcut: nn.Module
427
+ if true_skip:
428
+ self.shortcut = nn.Identity()
429
+ else:
430
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
431
+ causal=causal, pad_mode=pad_mode)
432
+
433
+ def forward(self, x):
434
+ return self.shortcut(x) + self.block(x)
435
+
436
+
437
+ class SEANetEncoder(nn.Module):
438
+ """SEANet encoder.
439
+
440
+ Args:
441
+ channels (int): Audio channels.
442
+ dimension (int): Intermediate representation dimension.
443
+ n_filters (int): Base width for the model.
444
+ n_residual_layers (int): nb of residual layers.
445
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
446
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
447
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
448
+ activation (str): Activation function.
449
+ activation_params (dict): Parameters to provide to the activation function.
450
+ norm (str): Normalization method.
451
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
452
+ kernel_size (int): Kernel size for the initial convolution.
453
+ last_kernel_size (int): Kernel size for the last convolution.
454
+ residual_kernel_size (int): Kernel size for the residual layers.
455
+ dilation_base (int): How much to increase the dilation with each layer.
456
+ causal (bool): Whether to use fully causal convolution.
457
+ pad_mode (str): Padding mode for the convolutions.
458
+ true_skip (bool): Whether to use true skip connection or a simple
459
+ (streamable) convolution as the skip connection in the residual network blocks.
460
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
461
+ lstm (int): Number of LSTM layers at the end of the encoder.
462
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
463
+ For the encoder, it corresponds to the N first blocks.
464
+ """
465
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
466
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
467
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
468
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
469
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
470
+ disable_norm_outer_blocks: int = 0):
471
+ super().__init__()
472
+ self.channels = channels
473
+ self.dimension = dimension
474
+ self.n_filters = n_filters
475
+ self.ratios = list(reversed(ratios))
476
+ del ratios
477
+ self.n_residual_layers = n_residual_layers
478
+ self.hop_length = np.prod(self.ratios)
479
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
480
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
481
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
482
+ "Number of blocks for which to disable norm is invalid." \
483
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
484
+
485
+ act = getattr(nn, activation)
486
+ mult = 1
487
+ model: tp.List[nn.Module] = [
488
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
489
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
490
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
491
+ ]
492
+ # Downsample to raw audio scale
493
+ for i, ratio in enumerate(self.ratios):
494
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
495
+ # Add residual layers
496
+ for j in range(n_residual_layers):
497
+ model += [
498
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
499
+ dilations=[dilation_base ** j, 1],
500
+ norm=block_norm, norm_params=norm_params,
501
+ activation=activation, activation_params=activation_params,
502
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
503
+
504
+ # Add downsampling layers
505
+ model += [
506
+ act(**activation_params),
507
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
508
+ kernel_size=ratio * 2, stride=ratio,
509
+ norm=block_norm, norm_kwargs=norm_params,
510
+ causal=causal, pad_mode=pad_mode),
511
+ ]
512
+ mult *= 2
513
+
514
+ if lstm:
515
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
516
+
517
+ model += [
518
+ act(**activation_params),
519
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
520
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
521
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
522
+ ]
523
+
524
+ self.model = nn.Sequential(*model)
525
+
526
+ def forward(self, x):
527
+ return self.model(x)
528
+
529
+
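# [Editor's note: illustrative arithmetic, not part of the uploaded file.]
# SEANetEncoder.hop_length above is the product of the stride ratios, and
# get_compression_model later in this module derives the token frame rate as
#     frame_rate = sample_rate // hop_length.
# With the default ratios [8, 5, 4, 2] (hop length 320) and an assumed 16 kHz
# sample rate, that gives 16000 // 320 = 50 frames per second; the actual
# values depend on the checkpoint's cfg.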
530
+ class SEANetDecoder(nn.Module):
531
+ """SEANet decoder.
532
+
533
+ Args:
534
+ channels (int): Audio channels.
535
+ dimension (int): Intermediate representation dimension.
536
+ n_filters (int): Base width for the model.
537
+ n_residual_layers (int): nb of residual layers.
538
+ ratios (Sequence[int]): kernel size and stride ratios.
539
+ activation (str): Activation function.
540
+ activation_params (dict): Parameters to provide to the activation function.
541
+ final_activation (str): Final activation function after all convolutions.
542
+ final_activation_params (dict): Parameters to provide to the activation function.
543
+ norm (str): Normalization method.
544
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
545
+ kernel_size (int): Kernel size for the initial convolution.
546
+ last_kernel_size (int): Kernel size for the last convolution.
547
+ residual_kernel_size (int): Kernel size for the residual layers.
548
+ dilation_base (int): How much to increase the dilation with each layer.
549
+ causal (bool): Whether to use fully causal convolution.
550
+ pad_mode (str): Padding mode for the convolutions.
551
+ true_skip (bool): Whether to use true skip connection or a simple
552
+ (streamable) convolution as the skip connection in the residual network blocks.
553
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
554
+ lstm (int): Number of LSTM layers at the end of the encoder.
555
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
556
+ For the decoder, it corresponds to the N last blocks.
557
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
558
+ If equal to 1.0, it means that all the trimming is done at the right.
559
+ """
560
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
561
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
562
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
563
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
564
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
565
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
566
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
567
+ super().__init__()
568
+ self.dimension = dimension
569
+ self.channels = channels
570
+ self.n_filters = n_filters
571
+ self.ratios = ratios
572
+ del ratios
573
+ self.n_residual_layers = n_residual_layers
574
+ self.hop_length = np.prod(self.ratios)
575
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
576
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
577
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
578
+ "Number of blocks for which to disable norm is invalid." \
579
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
580
+
581
+ act = getattr(nn, activation)
582
+ mult = int(2 ** len(self.ratios))
583
+ model: tp.List[nn.Module] = [
584
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
585
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
586
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
587
+ ]
588
+
589
+ if lstm:
590
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
591
+
592
+ # Upsample to raw audio scale
593
+ for i, ratio in enumerate(self.ratios):
594
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
595
+ # Add upsampling layers
596
+ model += [
597
+ act(**activation_params),
598
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
599
+ kernel_size=ratio * 2, stride=ratio,
600
+ norm=block_norm, norm_kwargs=norm_params,
601
+ causal=causal, trim_right_ratio=trim_right_ratio),
602
+ ]
603
+ # Add residual layers
604
+ for j in range(n_residual_layers):
605
+ model += [
606
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
607
+ dilations=[dilation_base ** j, 1],
608
+ activation=activation, activation_params=activation_params,
609
+ norm=block_norm, norm_params=norm_params, causal=causal,
610
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
611
+
612
+ mult //= 2
613
+
614
+ # Add final layers
615
+ model += [
616
+ act(**activation_params),
617
+ StreamableConv1d(n_filters, channels, last_kernel_size,
618
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
619
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
620
+ ]
621
+ # Add optional final activation to decoder (eg. tanh)
622
+ if final_activation is not None:
623
+ final_act = getattr(nn, final_activation)
624
+ final_activation_params = final_activation_params or {}
625
+ model += [
626
+ final_act(**final_activation_params)
627
+ ]
628
+ self.model = nn.Sequential(*model)
629
+
630
+ def forward(self, z):
631
+ y = self.model(z)
632
+ return y
633
+
634
+
635
+ def exists(val: tp.Optional[tp.Any]) -> bool:
636
+ return val is not None
637
+
638
+
639
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
640
+ return val if exists(val) else d
641
+
642
+
643
+ def l2norm(t):
644
+ return F.normalize(t, p=2, dim=-1)
645
+
646
+
647
+ def ema_inplace(moving_avg, new, decay: float):
648
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
649
+
650
+
651
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
652
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
653
+
654
+
655
+ def uniform_init(*shape: int):
656
+ t = torch.empty(shape)
657
+ nn.init.kaiming_uniform_(t)
658
+ return t
659
+
660
+
661
+ def sample_vectors(samples, num: int):
662
+ num_samples, device = samples.shape[0], samples.device
663
+
664
+ if num_samples >= num:
665
+ indices = torch.randperm(num_samples, device=device)[:num]
666
+ else:
667
+ indices = torch.randint(0, num_samples, (num,), device=device)
668
+
669
+ return samples[indices]
670
+
671
+
672
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
673
+ dim, dtype = samples.shape[-1], samples.dtype
674
+
675
+ means = sample_vectors(samples, num_clusters)
676
+
677
+ for _ in range(num_iters):
678
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
679
+ means, "c d -> () c d"
680
+ )
681
+ dists = -(diffs ** 2).sum(dim=-1)
682
+
683
+ buckets = dists.max(dim=-1).indices
684
+ bins = torch.bincount(buckets, minlength=num_clusters)
685
+ zero_mask = bins == 0
686
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
687
+
688
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
689
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
690
+ new_means = new_means / bins_min_clamped[..., None]
691
+
692
+ means = torch.where(zero_mask[..., None], means, new_means)
693
+
694
+ return means, bins
695
+
696
+
697
+ def orthogonal_loss_fn(t):
698
+ # eq (2) from https://arxiv.org/abs/2112.00384
699
+ n = t.shape[0]
700
+ normed_codes = l2norm(t)
701
+ identity = torch.eye(n, device=t.device)
702
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
703
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
704
+
705
+
706
+ class EuclideanCodebook(nn.Module):
707
+ """Codebook with Euclidean distance.
708
+
709
+ Args:
710
+ dim (int): Dimension.
711
+ codebook_size (int): Codebook size.
712
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
713
+ If set to true, run the k-means algorithm on the first training batch and use
714
+ the learned centroids as initialization.
715
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
716
+ decay (float): Decay for exponential moving average over the codebooks.
717
+ epsilon (float): Epsilon value for numerical stability.
718
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
719
+ that have an exponential moving average cluster size less than the specified threshold with
720
+ randomly selected vector from the current batch.
721
+ """
722
+ def __init__(
723
+ self,
724
+ dim: int,
725
+ codebook_size: int,
726
+ kmeans_init: int = False,
727
+ kmeans_iters: int = 10,
728
+ decay: float = 0.8,
729
+ epsilon: float = 1e-5,
730
+ threshold_ema_dead_code: int = 2,
731
+ ):
732
+ super().__init__()
733
+ self.decay = decay
734
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
735
+ embed = init_fn(codebook_size, dim)
736
+
737
+ self.codebook_size = codebook_size
738
+
739
+ self.kmeans_iters = kmeans_iters
740
+ self.epsilon = epsilon
741
+ self.threshold_ema_dead_code = threshold_ema_dead_code
742
+
743
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
744
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
745
+ self.register_buffer("embed", embed)
746
+ self.register_buffer("embed_avg", embed.clone())
747
+
748
+ @torch.jit.ignore
749
+ def init_embed_(self, data):
750
+ if self.inited:
751
+ return
752
+
753
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
754
+ self.embed.data.copy_(embed)
755
+ self.embed_avg.data.copy_(embed.clone())
756
+ self.cluster_size.data.copy_(cluster_size)
757
+ self.inited.data.copy_(torch.Tensor([True]))
758
+ # Make sure all buffers across workers are in sync after initialization
759
+ flashy.distrib.broadcast_tensors(self.buffers())
760
+
761
+ def replace_(self, samples, mask):
762
+ modified_codebook = torch.where(
763
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
764
+ )
765
+ self.embed.data.copy_(modified_codebook)
766
+
767
+ def expire_codes_(self, batch_samples):
768
+ if self.threshold_ema_dead_code == 0:
769
+ return
770
+
771
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
772
+ if not torch.any(expired_codes):
773
+ return
774
+
775
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
776
+ self.replace_(batch_samples, mask=expired_codes)
777
+ flashy.distrib.broadcast_tensors(self.buffers())
778
+
779
+ def preprocess(self, x):
780
+ x = rearrange(x, "... d -> (...) d")
781
+ return x
782
+
783
+ def quantize(self, x):
784
+ embed = self.embed.t()
785
+ dist = -(
786
+ x.pow(2).sum(1, keepdim=True)
787
+ - 2 * x @ embed
788
+ + embed.pow(2).sum(0, keepdim=True)
789
+ )
790
+ embed_ind = dist.max(dim=-1).indices
791
+ return embed_ind
792
+
793
+ def postprocess_emb(self, embed_ind, shape):
794
+ return embed_ind.view(*shape[:-1])
795
+
796
+ def dequantize(self, embed_ind):
797
+ quantize = F.embedding(embed_ind, self.embed)
798
+ return quantize
799
+
800
+ def encode(self, x):
801
+ shape = x.shape
802
+ # pre-process
803
+ x = self.preprocess(x)
804
+ # quantize
805
+ embed_ind = self.quantize(x)
806
+ # post-process
807
+ embed_ind = self.postprocess_emb(embed_ind, shape)
808
+ return embed_ind
809
+
810
+ def decode(self, embed_ind):
811
+ quantize = self.dequantize(embed_ind)
812
+ return quantize
813
+
814
+ def forward(self, x):
815
+ raise NotImplementedError("codebook training is not supported in this copy; the code below is unreachable and kept for reference")
816
+ shape, dtype = x.shape, x.dtype
817
+ x = self.preprocess(x)
818
+ self.init_embed_(x)
819
+
820
+ embed_ind = self.quantize(x)
821
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
822
+ embed_ind = self.postprocess_emb(embed_ind, shape)
823
+ quantize = self.dequantize(embed_ind)
824
+
825
+ if self.training:
826
+ # We do the expiry of code at that point as buffers are in sync
827
+ # and all the workers will take the same decision.
828
+ self.expire_codes_(x)
829
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
830
+ embed_sum = x.t() @ embed_onehot
831
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
832
+ cluster_size = (
833
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
834
+ * self.cluster_size.sum()
835
+ )
836
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
837
+ self.embed.data.copy_(embed_normalized)
838
+
839
+ return quantize, embed_ind
840
+
841
+
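# Illustration (not part of the file above): `quantize` finds the nearest codeword by
# expanding -||x - e||^2 = -(||x||^2 - 2 x.e + ||e||^2) and taking the argmax, which is
# equivalent to an argmin over explicit Euclidean distances. Sizes below are made up.
import torch
x = torch.randn(7, 4)                     # 7 flattened input vectors of dim 4
embed = torch.randn(16, 4)                # a 16-entry codebook
embed_t = embed.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed_t + embed_t.pow(2).sum(0, keepdim=True))
assert torch.equal(dist.max(dim=-1).indices, torch.cdist(x, embed).argmin(dim=-1))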
842
+ class VectorQuantization(nn.Module):
843
+ """Vector quantization implementation.
844
+ Currently supports only euclidean distance.
845
+
846
+ Args:
847
+ dim (int): Dimension
848
+ codebook_size (int): Codebook size
849
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
850
+ decay (float): Decay for exponential moving average over the codebooks.
851
+ epsilon (float): Epsilon value for numerical stability.
852
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
853
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
854
+ threshold_ema_dead_code (int):
855
+ channels_last (bool): Channels are the last dimension in the input tensors.
856
+ commitment_weight (float): Weight for commitment loss.
857
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
858
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
859
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
860
+ for orthogonal regularization.
861
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
862
+ that have an exponential moving average cluster size less than the specified threshold with
863
+ a randomly selected vector from the current batch.
864
+ """
865
+ def __init__(
866
+ self,
867
+ dim: int,
868
+ codebook_size: int,
869
+ codebook_dim: tp.Optional[int] = None,
870
+ decay: float = 0.8,
871
+ epsilon: float = 1e-5,
872
+ kmeans_init: bool = False,
873
+ kmeans_iters: int = 10,
874
+ threshold_ema_dead_code: int = 2,
875
+ channels_last: bool = False,
876
+ commitment_weight: float = 1.,
877
+ orthogonal_reg_weight: float = 0.0,
878
+ orthogonal_reg_active_codes_only: bool = False,
879
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
880
+ ):
881
+ super().__init__()
882
+ _codebook_dim: int = default(codebook_dim, dim)
883
+
884
+ requires_projection = _codebook_dim != dim
885
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
886
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
887
+
888
+ self.epsilon = epsilon
889
+ self.commitment_weight = commitment_weight
890
+
891
+ self.orthogonal_reg_weight = orthogonal_reg_weight
892
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
893
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
894
+
895
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
896
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
897
+ decay=decay, epsilon=epsilon,
898
+ threshold_ema_dead_code=threshold_ema_dead_code)
899
+ self.codebook_size = codebook_size
900
+
901
+ self.channels_last = channels_last
902
+
903
+ @property
904
+ def codebook(self):
905
+ return self._codebook.embed
906
+
907
+ @property
908
+ def inited(self):
909
+ return self._codebook.inited
910
+
911
+ def _preprocess(self, x):
912
+ if not self.channels_last:
913
+ x = rearrange(x, "b d n -> b n d")
914
+ return x
915
+
916
+ def _postprocess(self, quantize):
917
+ if not self.channels_last:
918
+ quantize = rearrange(quantize, "b n d -> b d n")
919
+ return quantize
920
+
921
+ def encode(self, x):
922
+ x = self._preprocess(x)
923
+ x = self.project_in(x)
924
+ embed_in = self._codebook.encode(x)
925
+ return embed_in
926
+
927
+ def decode(self, embed_ind):
928
+ quantize = self._codebook.decode(embed_ind)
929
+ quantize = self.project_out(quantize)
930
+ quantize = self._postprocess(quantize)
931
+ return quantize
932
+
933
+ def forward(self, x):
934
+ device = x.device
935
+ x = self._preprocess(x)
936
+
937
+ x = self.project_in(x)
938
+ quantize, embed_ind = self._codebook(x)
939
+
940
+ if self.training:
941
+ quantize = x + (quantize - x).detach()
942
+
943
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
944
+
945
+ if self.training:
946
+ if self.commitment_weight > 0:
947
+ commit_loss = F.mse_loss(quantize.detach(), x)
948
+ loss = loss + commit_loss * self.commitment_weight
949
+
950
+ if self.orthogonal_reg_weight > 0:
951
+ codebook = self.codebook
952
+
953
+ if self.orthogonal_reg_active_codes_only:
954
+ # only calculate orthogonal loss for the activated codes for this batch
955
+ unique_code_ids = torch.unique(embed_ind)
956
+ codebook = codebook[unique_code_ids]
957
+
958
+ num_codes = codebook.shape[0]
959
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
960
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
961
+ codebook = codebook[rand_ids]
962
+
963
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
964
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
965
+
966
+ quantize = self.project_out(quantize)
967
+ quantize = self._postprocess(quantize)
968
+
969
+ return quantize, embed_ind, loss
970
+
971
+
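# Illustration (standalone, made-up shapes): the straight-through estimator used above,
# quantize = x + (quantize - x).detach(), keeps the quantized value in the forward pass
# while letting gradients bypass the non-differentiable codebook lookup and reach x.
import torch
x = torch.randn(3, 5, requires_grad=True)
quantize = torch.randn(3, 5)              # stand-in for the codebook output
ste = x + (quantize - x).detach()
assert torch.allclose(ste, quantize)      # forward: value of the codeword
ste.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))  # backward: gradient goes to x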
972
+ class ResidualVectorQuantization(nn.Module):
973
+ """Residual vector quantization implementation.
974
+
975
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
976
+ """
977
+ def __init__(self, *, num_quantizers, **kwargs):
978
+ super().__init__()
979
+ codebook_size = kwargs.pop('codebook_size', None)
980
+ if codebook_size is None:
981
+ raise ValueError("codebook_size must be provided in kwargs")
982
+ if type(codebook_size) != list:
983
+ codebook_size = [codebook_size] * num_quantizers
984
+ self.layers = nn.ModuleList(
985
+ [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)]
986
+ )
987
+
988
+
989
+ # self.layers = nn.ModuleList(
990
+ # [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
991
+ # )
992
+
993
+ def forward(self, x, n_q: tp.Optional[int] = None):
994
+ quantized_out = 0.0
995
+ residual = x
996
+
997
+ all_losses = []
998
+ all_indices = []
999
+
1000
+ n_q = n_q or len(self.layers)
1001
+
1002
+ for i, layer in enumerate(self.layers[:n_q]):
1003
+ quantized, indices, loss = layer(residual)
1004
+ residual = residual - quantized
1005
+ quantized_out = quantized_out + quantized
1006
+ all_indices.append(indices)
1007
+ all_losses.append(loss)
1008
+
1009
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
1010
+ return quantized_out, out_indices, out_losses
1011
+
1012
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
1013
+ residual = x
1014
+ all_indices = []
1015
+ n_q = n_q or len(self.layers)
1016
+ for layer in self.layers[:n_q]:
1017
+ indices = layer.encode(residual)
1018
+ quantized = layer.decode(indices)
1019
+ # the original code is below
1020
+ # since quantize has the gradient of residual, according to line 321
1021
+ # quantize = x + (quantize - x).detach()
1022
+ # the code below would make the commitment loss 0 for all codebooks except the first one
1023
+ # https://github.com/facebookresearch/encodec/issues/25
1024
+ # therefore we change it
1025
+
1026
+ residual = residual - quantized
1027
+ # residual = residual - quantized.detach()
1028
+ # since the commitment loss is averaged, the scale of the loss does not change (contrary to what the issue above suggests)
1029
+ all_indices.append(indices)
1030
+ out_indices = torch.stack(all_indices)
1031
+ return out_indices
1032
+
1033
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
1034
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
1035
+ for i, indices in enumerate(q_indices):
1036
+ layer = self.layers[i]
1037
+ quantized = layer.decode(indices)
1038
+ quantized_out = quantized_out + quantized
1039
+ return quantized_out
1040
+
1041
+
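# Illustration (toy stand-in, not the module above): residual VQ quantizes what the
# previous stage left over, and the decoded signal is the sum of the per-stage outputs.
# Grid rounding plays the role of each VectorQuantization layer here.
import torch
def toy_quantize(v, step):
    return torch.round(v / step) * step
x = torch.randn(10)
residual, quantized_out = x, torch.zeros_like(x)
for step in (0.5, 0.25, 0.125):           # three "codebooks" of increasing resolution
    q = toy_quantize(residual, step)
    quantized_out, residual = quantized_out + q, residual - q
assert torch.allclose(quantized_out + residual, x)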
1042
+ class ResidualVectorQuantizer(BaseQuantizer):
1043
+ """Residual Vector Quantizer.
1044
+
1045
+ Args:
1046
+ dimension (int): Dimension of the codebooks.
1047
+ n_q (int): Number of residual vector quantizers used.
1048
+ q_dropout (bool): Random quantizer drop out at train time.
1049
+ bins (int): Codebook size.
1050
+ decay (float): Decay for exponential moving average over the codebooks.
1051
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
1052
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
1053
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
1054
+ that have an exponential moving average cluster size less than the specified threshold with
1055
+ a randomly selected vector from the current batch.
1056
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
1057
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
1058
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
1059
+ for orthogonal regularization.
1060
+ """
1061
+ def __init__(
1062
+ self,
1063
+ dimension: int = 256,
1064
+ n_q: int = 8,
1065
+ q_dropout: bool = False,
1066
+ bins: tp.Union[int, tp.List[int]] = 1024,
1067
+ decay: float = 0.99,
1068
+ kmeans_init: bool = True,
1069
+ kmeans_iters: int = 10,
1070
+ threshold_ema_dead_code: int = 2,
1071
+ orthogonal_reg_weight: float = 0.0,
1072
+ orthogonal_reg_active_codes_only: bool = False,
1073
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
1074
+ ):
1075
+ super().__init__()
1076
+ self.max_n_q = n_q
1077
+ self.n_q = n_q
1078
+ self.q_dropout = q_dropout
1079
+ self.dimension = dimension
1080
+ self.bins = bins
1081
+ self.decay = decay
1082
+ self.kmeans_init = kmeans_init
1083
+ self.kmeans_iters = kmeans_iters
1084
+ self.threshold_ema_dead_code = threshold_ema_dead_code
1085
+ self.orthogonal_reg_weight = orthogonal_reg_weight
1086
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
1087
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
1088
+ self.vq = ResidualVectorQuantization(
1089
+ dim=self.dimension,
1090
+ codebook_size=self.bins,
1091
+ num_quantizers=self.n_q,
1092
+ decay=self.decay,
1093
+ kmeans_init=self.kmeans_init,
1094
+ kmeans_iters=self.kmeans_iters,
1095
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
1096
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
1097
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
1098
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
1099
+ channels_last=False
1100
+ )
1101
+
1102
+ def forward(self, x: torch.Tensor, frame_rate: int):
1103
+ n_q = self.n_q
1104
+ if self.training and self.q_dropout:
1105
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
1106
+ if type(self.bins) == list:
1107
+ bins = self.bins
1108
+ else:
1109
+ bins = [self.bins] * self.n_q
1110
+ bw_per_q = [math.log2(bin) * frame_rate / 1000 for bin in bins]
1111
+ bw = torch.tensor(sum(bw_per_q)).to(x)
1112
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
1113
+ codes = codes.transpose(0, 1)
1114
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1115
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
1116
+
1117
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1118
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
1119
+ The RVQ encode method sets the appropriate number of quantizers to use
1120
+ and returns indices for each quantizer.
1121
+ """
1122
+ n_q = self.n_q
1123
+ codes = self.vq.encode(x, n_q=n_q)
1124
+ codes = codes.transpose(0, 1)
1125
+ # codes is [B, K, T], with T frames, K nb of codebooks.
1126
+ return codes
1127
+
1128
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1129
+ """Decode the given codes to the quantized representation."""
1130
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
1131
+ codes = codes.transpose(0, 1)
1132
+ quantized = self.vq.decode(codes)
1133
+ return quantized
1134
+
1135
+ @property
1136
+ def total_codebooks(self):
1137
+ return self.max_n_q
1138
+
1139
+ @property
1140
+ def num_codebooks(self):
1141
+ return self.n_q
1142
+
1143
+ def set_num_codebooks(self, n: int):
1144
+ assert n > 0 and n <= self.max_n_q
1145
+ self.n_q = n
1146
+
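# Worked example of the bandwidth computed in `forward` above, with hypothetical
# settings: 8 codebooks of 1024 bins at a 50 Hz frame rate give log2(1024) = 10 bits
# per codebook per frame, i.e. 0.5 kbps each and 4.0 kbps in total.
import math
frame_rate, bins, n_q = 50, 1024, 8
bw_per_q = math.log2(bins) * frame_rate / 1000
print(n_q * bw_per_q)                     # 4.0 (kbps)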
1147
+ class DummyQuantizer(BaseQuantizer):
1148
+ """Fake quantizer that actually does not perform any quantization.
1149
+ """
1150
+ def __init__(self):
1151
+ super().__init__()
1152
+
1153
+ def forward(self, x: torch.Tensor, frame_rate: int):
1154
+ q = x.unsqueeze(1)
1155
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
1156
+
1157
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
1158
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
1159
+ In the case of the DummyQuantizer, the codes are actually identical
1160
+ to the input and resulting quantized representation as no quantization is done.
1161
+ """
1162
+ return x.unsqueeze(1)
1163
+
1164
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
1165
+ """Decode the given codes to the quantized representation.
1166
+ In the case of the DummyQuantizer, the codes are actually identical
1167
+ to the input and resulting quantized representation as no quantization is done.
1168
+ """
1169
+ return codes.squeeze(1)
1170
+
1171
+ @property
1172
+ def total_codebooks(self):
1173
+ """Total number of codebooks."""
1174
+ return 1
1175
+
1176
+ @property
1177
+ def num_codebooks(self):
1178
+ """Total number of codebooks."""
1179
+ return self.total_codebooks
1180
+
1181
+ def set_num_codebooks(self, n: int):
1182
+ """Set the number of active codebooks."""
1183
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
1184
+
1185
+
1186
+ class EncodecModel(CompressionModel):
1187
+ """Encodec model operating on the raw waveform.
1188
+
1189
+ Args:
1190
+ encoder (nn.Module): Encoder network.
1191
+ decoder (nn.Module): Decoder network.
1192
+ quantizer (BaseQuantizer): Quantizer network.
1193
+ frame_rate (int): Frame rate for the latent representation.
1194
+ sample_rate (int): Audio sample rate.
1195
+ channels (int): Number of audio channels.
1196
+ causal (bool): Whether to use a causal version of the model.
1197
+ renormalize (bool): Whether to renormalize the audio before running the model.
1198
+ """
1199
+ # we need assignment to override the property in the abstract class,
1200
+ # I couldn't find a better way...
1201
+ frame_rate: float = 0
1202
+ sample_rate: int = 0
1203
+ channels: int = 0
1204
+
1205
+ def __init__(self,
1206
+ encoder: nn.Module,
1207
+ decoder: nn.Module,
1208
+ quantizer: BaseQuantizer,
1209
+ frame_rate: int,
1210
+ sample_rate: int,
1211
+ channels: int,
1212
+ causal: bool = False,
1213
+ renormalize: bool = False):
1214
+ super().__init__()
1215
+ self.encoder = encoder
1216
+ self.decoder = decoder
1217
+ self.quantizer = quantizer
1218
+ self.frame_rate = frame_rate
1219
+ self.sample_rate = sample_rate
1220
+ self.channels = channels
1221
+ self.renormalize = renormalize
1222
+ self.causal = causal
1223
+ if self.causal:
1224
+ # we force disabling here to avoid handling linear overlap of segments
1225
+ # as supported in original EnCodec codebase.
1226
+ assert not self.renormalize, 'Causal model does not support renormalize'
1227
+
1228
+ @property
1229
+ def total_codebooks(self):
1230
+ """Total number of quantizer codebooks available."""
1231
+ return self.quantizer.total_codebooks
1232
+
1233
+ @property
1234
+ def num_codebooks(self):
1235
+ """Active number of codebooks used by the quantizer."""
1236
+ return self.quantizer.num_codebooks
1237
+
1238
+ def set_num_codebooks(self, n: int):
1239
+ """Set the active number of codebooks used by the quantizer."""
1240
+ self.quantizer.set_num_codebooks(n)
1241
+
1242
+ @property
1243
+ def cardinality(self):
1244
+ """Cardinality of each codebook."""
1245
+ return self.quantizer.bins
1246
+
1247
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1248
+ scale: tp.Optional[torch.Tensor]
1249
+ if self.renormalize:
1250
+ mono = x.mean(dim=1, keepdim=True)
1251
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1252
+ scale = 1e-8 + volume
1253
+ x = x / scale
1254
+ scale = scale.view(-1, 1)
1255
+ else:
1256
+ scale = None
1257
+ return x, scale
1258
+
1259
+ def postprocess(self,
1260
+ x: torch.Tensor,
1261
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1262
+ if scale is not None:
1263
+ assert self.renormalize
1264
+ x = x * scale.view(-1, 1, 1)
1265
+ return x
1266
+
1267
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1268
+ if encode:
1269
+ return self.encode(x)
1270
+ else:
1271
+ raise NotImplementedError("model forward and training is not supported.")
1272
+ assert x.dim() == 3
1273
+ length = x.shape[-1]
1274
+ x, scale = self.preprocess(x)
1275
+
1276
+ emb = self.encoder(x)
1277
+ q_res = self.quantizer(emb, self.frame_rate)
1278
+ out = self.decoder(q_res.x)
1279
+
1280
+ # remove extra padding added by the encoder and decoder
1281
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1282
+ out = out[..., :length]
1283
+
1284
+ q_res.x = self.postprocess(out, scale)
1285
+
1286
+ return q_res
1287
+
1288
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1289
+ """Encode the given input tensor to quantized representation along with scale parameter.
1290
+
1291
+ Args:
1292
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1293
+
1294
+ Returns:
1295
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1296
+ codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the number of frames.
1297
+ scale, a float tensor containing the scale for audio renormalization.
1298
+ """
1299
+ assert x.dim() == 3
1300
+ x, scale = self.preprocess(x)
1301
+ emb = self.encoder(x)
1302
+ codes = self.quantizer.encode(emb)
1303
+ return codes, scale
1304
+
1305
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1306
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1307
+ audio denormalization if needed.
1308
+
1309
+ Args:
1310
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1311
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1312
+
1313
+ Returns:
1314
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1315
+ """
1316
+ emb = self.decode_latent(codes)
1317
+ out = self.decoder(emb)
1318
+ out = self.postprocess(out, scale)
1319
+ # out contains extra padding added by the encoder and decoder
1320
+ return out
1321
+
1322
+ def decode_latent(self, codes: torch.Tensor):
1323
+ """Decode from the discrete codes to continuous latent space."""
1324
+ return self.quantizer.decode(codes)
1325
+
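# Shape sketch for the model above (hypothetical numbers): `encode` maps a waveform
# [B, C, T] to integer codes [B, K, T_frames], where T_frames is T divided by the
# encoder hop length and frame_rate = sample_rate // hop_length (see get_compression_model below).
sample_rate, hop_length, seconds = 16000, 320, 3.0
frame_rate = sample_rate // hop_length    # 50 code frames per second
print(frame_rate, int(seconds * sample_rate / hop_length))   # 50 150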
1326
+ class EncodecModel_encode_only(CompressionModel):
1327
+ """Encodec model operating on the raw waveform. Encode only, so no decoder
1328
+
1329
+ Args:
1330
+ encoder (nn.Module): Encoder network.
1331
+ quantizer (BaseQuantizer): Quantizer network.
1332
+ frame_rate (int): Frame rate for the latent representation.
1333
+ sample_rate (int): Audio sample rate.
1334
+ channels (int): Number of audio channels.
1335
+ causal (bool): Whether to use a causal version of the model.
1336
+ renormalize (bool): Whether to renormalize the audio before running the model.
1337
+ """
1338
+ # we need assignment to override the property in the abstract class,
1339
+ # I couldn't find a better way...
1340
+ frame_rate: float = 0
1341
+ sample_rate: int = 0
1342
+ channels: int = 0
1343
+
1344
+ def __init__(self,
1345
+ encoder: nn.Module,
1346
+ quantizer: BaseQuantizer,
1347
+ frame_rate: int,
1348
+ sample_rate: int,
1349
+ channels: int,
1350
+ causal: bool = False,
1351
+ renormalize: bool = False):
1352
+ super().__init__()
1353
+ self.encoder = encoder
1354
+ self.quantizer = quantizer
1355
+ self.frame_rate = frame_rate
1356
+ self.sample_rate = sample_rate
1357
+ self.channels = channels
1358
+ self.renormalize = renormalize
1359
+ self.causal = causal
1360
+ if self.causal:
1361
+ # we force disabling here to avoid handling linear overlap of segments
1362
+ # as supported in original EnCodec codebase.
1363
+ assert not self.renormalize, 'Causal model does not support renormalize'
1364
+
1365
+ @property
1366
+ def total_codebooks(self):
1367
+ """Total number of quantizer codebooks available."""
1368
+ return self.quantizer.total_codebooks
1369
+
1370
+ @property
1371
+ def num_codebooks(self):
1372
+ """Active number of codebooks used by the quantizer."""
1373
+ return self.quantizer.num_codebooks
1374
+
1375
+ def set_num_codebooks(self, n: int):
1376
+ """Set the active number of codebooks used by the quantizer."""
1377
+ self.quantizer.set_num_codebooks(n)
1378
+
1379
+ @property
1380
+ def cardinality(self):
1381
+ """Cardinality of each codebook."""
1382
+ return self.quantizer.bins
1383
+
1384
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1385
+ scale: tp.Optional[torch.Tensor]
1386
+ if self.renormalize:
1387
+ mono = x.mean(dim=1, keepdim=True)
1388
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
1389
+ scale = 1e-8 + volume
1390
+ x = x / scale
1391
+ scale = scale.view(-1, 1)
1392
+ else:
1393
+ scale = None
1394
+ return x, scale
1395
+
1396
+ def postprocess(self,
1397
+ x: torch.Tensor,
1398
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
1399
+ if scale is not None:
1400
+ assert self.renormalize
1401
+ x = x * scale.view(-1, 1, 1)
1402
+ return x
1403
+
1404
+ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
1405
+ if encode:
1406
+ return self.encode(x)
1407
+ else:
1408
+ raise NotImplementedError("model forward and training is not supported.")
1409
+ assert x.dim() == 3
1410
+ length = x.shape[-1]
1411
+ x, scale = self.preprocess(x)
1412
+
1413
+ emb = self.encoder(x)
1414
+ q_res = self.quantizer(emb, self.frame_rate)
1415
+ out = self.decoder(q_res.x)
1416
+
1417
+ # remove extra padding added by the encoder and decoder
1418
+ assert out.shape[-1] >= length, (out.shape[-1], length)
1419
+ out = out[..., :length]
1420
+
1421
+ q_res.x = self.postprocess(out, scale)
1422
+
1423
+ return q_res
1424
+
1425
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1426
+ """Encode the given input tensor to quantized representation along with scale parameter.
1427
+
1428
+ Args:
1429
+ x (torch.Tensor): Float tensor of shape [B, C, T]
1430
+
1431
+ Returns:
1432
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
1433
+ codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the number of frames.
1434
+ scale, a float tensor containing the scale for audio renormalization.
1435
+ """
1436
+ assert x.dim() == 3
1437
+ x, scale = self.preprocess(x)
1438
+ emb = self.encoder(x)
1439
+ codes = self.quantizer.encode(emb)
1440
+ return codes, scale
1441
+
1442
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
1443
+ """Decode the given codes to a reconstructed representation, using the scale to perform
1444
+ audio denormalization if needed.
1445
+
1446
+ Args:
1447
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
1448
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
1449
+
1450
+ Returns:
1451
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
1452
+ """
1453
+ raise NotImplementedError("Decode is not supported for encode only model")
1454
+ emb = self.decode_latent(codes)
1455
+ out = self.decoder(emb)
1456
+ out = self.postprocess(out, scale)
1457
+ # out contains extra padding added by the encoder and decoder
1458
+ return out
1459
+
1460
+ def decode_latent(self, codes: torch.Tensor):
1461
+ """Decode from the discrete codes to continuous latent space."""
1462
+ raise NotImplementedError("Decode is not supported for encode only model")
1463
+ return self.quantizer.decode(codes)
1464
+
1465
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
1466
+ klass = {
1467
+ 'no_quant': DummyQuantizer,
1468
+ 'rvq': ResidualVectorQuantizer
1469
+ }[quantizer]
1470
+ kwargs = dict_from_config(getattr(cfg, quantizer))
1471
+ if quantizer != 'no_quant':
1472
+ kwargs['dimension'] = dimension
1473
+ return klass(**kwargs)
1474
+
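# Usage sketch for get_quantizer with a made-up config; the real values come from the
# checkpoint's `xp.cfg` loaded in get_compression_model below, so the numbers here are
# only placeholders chosen to match ResidualVectorQuantizer's signature.
import omegaconf
from encodec import get_quantizer   # assuming this module is importable as `encodec`, as step4 does
cfg = omegaconf.OmegaConf.create({
    "rvq": {"n_q": 8, "bins": 1024, "q_dropout": False, "decay": 0.99,
            "kmeans_init": True, "kmeans_iters": 10, "threshold_ema_dead_code": 2},
})
quantizer = get_quantizer("rvq", cfg, dimension=128)   # a ResidualVectorQuantizer
print(quantizer.total_codebooks)                       # 8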
1475
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
1476
+ if encoder_name == 'seanet':
1477
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
1478
+ encoder_override_kwargs = kwargs.pop('encoder')
1479
+ decoder_override_kwargs = kwargs.pop('decoder')
1480
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
1481
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
1482
+ encoder = SEANetEncoder(**encoder_kwargs)
1483
+ decoder = SEANetDecoder(**decoder_kwargs)
1484
+ return encoder, decoder
1485
+ else:
1486
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
1487
+
1488
+
1489
+ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
1490
+ """Instantiate a compression model."""
1491
+ if device is None:
1492
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1493
+ state = torch.load(ckpt_fn, map_location='cpu')
1494
+ cfg = state['xp.cfg']
1495
+ cfg.device = str(device)
1496
+ weights = state['best_state']['model']
1497
+ assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
1498
+ if encode_only:
1499
+ all_keys = list(weights.keys())
1500
+ for key in all_keys:
1501
+ if key.startswith('decoder'):
1502
+ del weights[key]
1503
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1504
+ encoder_name = kwargs.pop('autoencoder')
1505
+ quantizer_name = kwargs.pop('quantizer')
1506
+ encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
1507
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1508
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1509
+ renormalize = kwargs.pop('renormalize', False)
1510
+ # deprecated params
1511
+ kwargs.pop('renorm', None)
1512
+ compression_model = EncodecModel_encode_only(encoder, quantizer,
1513
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1514
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1515
+ compression_model.load_state_dict(weights)
1516
+ compression_model.eval()
1517
+ return compression_model
1518
+
1519
+ else:
1520
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
1521
+ encoder_name = kwargs.pop('autoencoder')
1522
+ quantizer_name = kwargs.pop('quantizer')
1523
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
1524
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
1525
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
1526
+ renormalize = kwargs.pop('renormalize', False)
1527
+ # deprecated params
1528
+ kwargs.pop('renorm', None)
1529
+ compression_model = EncodecModel(encoder, decoder, quantizer,
1530
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
1531
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
1532
+ compression_model.load_state_dict(weights)
1533
+ compression_model.eval()
1534
+ return compression_model
1535
+
1536
+ if __name__ == "__main__":
1537
+ import torchaudio
1538
+ ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
1539
+ audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
1540
+ audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
1541
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1542
+ model = get_compression_model(ckpt_fn, device=device)
1543
+
1544
+ for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
1545
+ audio_in, sr = torchaudio.load(audio_in_fn)
1546
+ if sr != model.sample_rate:
1547
+ audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
1548
+ if audio_in.shape[0] == 2:
1549
+ audio_in = audio_in.mean(dim=0, keepdim=True)
1550
+ audio_in = audio_in.unsqueeze(0)
1551
+ audio_in = audio_in.to(torch.float32).to(device)
1552
+ codes = model.encode(audio_in)[0]
1553
+ audio_out = model.decode(codes)[0].cpu()
1554
+ torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
data/ll60k_preprocessing/step1_download.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # define where to store the downloaded data
2
+ dataroot=$DATAROOT
3
+ mkdir -p $dataroot
4
+ manifestroot=$dataroot/libriheavy
5
+ mkdir -p $manifestroot
6
+ audioroot=$dataroot/audio
7
+ mkdir -p $audioroot
8
+
9
+ # download libriheavy_long and libriheavy
10
+ cd $manifestroot
11
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_dev.jsonl.gz?download=true -O libriheavy_cuts_dev.jsonl.gz
12
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_test_clean.jsonl.gz?download=true -O libriheavy_cuts_test_clean.jsonl.gz
13
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_test_other.jsonl.gz?download=true -O libriheavy_cuts_test_other.jsonl.gz
14
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_small.jsonl.gz?download=true -O libriheavy_cuts_small.jsonl.gz
15
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_medium.jsonl.gz?download=true -O libriheavy_cuts_medium.jsonl.gz
16
+ wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_large.jsonl.gz?download=true -O libriheavy_cuts_large.jsonl.gz
17
+ wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_small.jsonl.gz?download=true -O libriheavy_long_original_cuts_small.jsonl.gz
18
+ wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_medium.jsonl.gz?download=true -O libriheavy_long_original_cuts_medium.jsonl.gz
19
+ wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_large.jsonl.gz?download=true -O libriheavy_long_original_cuts_large.jsonl.gz
20
+
21
+ # turn .jsonl.gz to .jsonl
22
+ gunzip -k libriheavy_cuts_dev.jsonl.gz
23
+ gunzip -k libriheavy_cuts_test_clean.jsonl.gz
24
+ gunzip -k libriheavy_cuts_test_other.jsonl.gz
25
+ gunzip -k libriheavy_cuts_small.jsonl.gz
26
+ gunzip -k libriheavy_cuts_medium.jsonl.gz
27
+ gunzip -k libriheavy_cuts_large.jsonl.gz
28
+ gunzip -k libriheavy_long_original_cuts_small.jsonl.gz
29
+ gunzip -k libriheavy_long_original_cuts_medium.jsonl.gz
30
+ gunzip -k libriheavy_long_original_cuts_large.jsonl.gz
31
+
32
+ # if librilight is already unzipped in origDATAROOT, then skip this step
33
+ # download ll
34
+ cd $audioroot
35
+ wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
36
+ wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
37
+ wget https://dl.fbaipublicfiles.com/librilight/data/large.tar
38
+
39
+ # untar small, medium, large
40
+ tar -xf small.tar
41
+ tar -xf medium.tar
42
+ tar -xf large.tar
data/ll60k_preprocessing/step2_resplit_long.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # find split, spk, books in libriheavy_cuts_dev.jsonl, libriheavy_cuts_test_clean.jsonl, libriheavy_cuts_test_other.jsonl
2
+ # those would be in "id" field
3
+
4
+ import sys
5
+ import os, random, numpy as np, socket
6
+ import json
7
+ import tqdm
8
+ def write_jsonl(data, fn):
9
+ with open(fn, "w") as file:
10
+ for entry in data:
11
+ file.write(json.dumps(entry, ensure_ascii=False) + "\n")
12
+ def read_jsonl(file_path):
13
+ cur_data = []
14
+ with open(file_path, 'r', encoding='utf-8-sig') as file:
15
+ for line in file:
16
+ cur_data.append(json.loads(line.strip()))
17
+ return cur_data
18
+ import os
19
+ dataroot=os.environ["DATAROOT"]
20
+ manifestroot=os.path.join(dataroot, "libriheavy")
21
+ tgt_names = ['libriheavy_cuts_dev.jsonl', 'libriheavy_cuts_test_clean.jsonl', 'libriheavy_cuts_test_other.jsonl']
22
+ orig_names = ['libriheavy_long_original_cuts_small.jsonl', 'libriheavy_long_original_cuts_medium.jsonl', 'libriheavy_long_original_cuts_large.jsonl']
23
+
24
+ id2split = {}
25
+ data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_dev.jsonl"))
26
+ dev_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])
27
+ data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_clean.jsonl"))
28
+ test_clean_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])
29
+ data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_other.jsonl"))
30
+ test_other_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])
31
+
32
+ long_dev = []
33
+ long_test_clean = []
34
+ long_test_other = []
35
+ for orig_name in orig_names:
36
+ keep = []
37
+ data = read_jsonl(os.path.join(manifestroot, orig_name))
38
+ for item in tqdm.tqdm(data):
39
+ if "/".join(item['id'].split("/")[:3]) in dev_ids:
40
+ long_dev.append(item)
41
+ elif "/".join(item['id'].split("/")[:3]) in test_clean_ids:
42
+ long_test_clean.append(item)
43
+ elif "/".join(item['id'].split("/")[:3]) in test_other_ids:
44
+ long_test_other.append(item)
45
+ else:
46
+ keep.append(item)
47
+ write_jsonl(keep, os.path.join(manifestroot, orig_name.replace("_original", "")))
48
+
49
+ write_jsonl(long_dev, os.path.join(manifestroot, "libriheavy_long_cuts_dev.jsonl"))
50
+ write_jsonl(long_test_clean, os.path.join(manifestroot, "libriheavy_long_cuts_test_clean.jsonl"))
51
+ write_jsonl(long_test_other, os.path.join(manifestroot, "libriheavy_long_cuts_test_other.jsonl"))
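# Worked example of the grouping key used above, with a hypothetical cut id: the first
# three components (subset/speaker/book) decide whether a long cut belongs to the dev
# or test speakers and must therefore be moved out of the training manifests.
cut_id = "medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb_0012"
print("/".join(cut_id.split("/")[:3]))    # medium/100/emerald_city_librivox_64kb_mp3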
data/ll60k_preprocessing/step3_seg_phn_manifest.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib.resources import path
2
+ import pathlib
3
+ import soundfile as sf
4
+ import numpy as np
5
+ import json
6
+ import multiprocessing
7
+ import argparse
8
+ import tqdm
9
+ import gzip
10
+ import time
11
+ import os
12
+ from tokenizer import TextTokenizer, tokenize_text
13
+ import glob
14
+ import sys
15
+ import os, random, numpy as np, socket
16
+ import json
17
+ import tqdm
18
+ def write_jsonl(data, fn):
19
+ with open(fn, "w") as file:
20
+ for entry in data:
21
+ file.write(json.dumps(entry, ensure_ascii=False) + "\n")
22
+ def read_jsonl(file_path):
23
+ cur_data = []
24
+ with open(file_path, 'r', encoding='utf-8-sig') as file:
25
+ for line in file:
26
+ cur_data.append(json.loads(line.strip()))
27
+ return cur_data
28
+ def save_audio(seq, fn):
29
+ output = seq
30
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
31
+ sf.write(fn, output, samplerate=16000)
32
+
33
+ def save_text(text, fn):
34
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
35
+ with open(fn, "w") as wwf:
36
+ wwf.writelines(text)
37
+
38
+ def phonemize_and_save(text, fn):
39
+ phn = tokenize_text(text_tokenizer, text)
40
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
41
+ with open(fn, "w") as f:
42
+ f.write(' '.join(phn))
43
+ return set(phn)
44
+
45
+ def cut_sequence(task):
46
+ in_audio_fn, output_dir, metadata = task
47
+ if not os.path.isfile(in_audio_fn):
48
+ # print("missing: ", in_audio_fn)
49
+ return None
50
+ data, samplerate = sf.read(in_audio_fn)
51
+ assert len(data.shape) == 1
52
+ assert samplerate == 16000
53
+ all_phns = set()
54
+ for item in metadata:
55
+ out_fn = item['file_id']
56
+ out_audio_fn = os.path.join(output_dir, "audio", out_fn)
57
+ out_text_fn = os.path.join(output_dir, "audio", out_fn.replace(".flac", ".txt"))
58
+ out_phn_fn = os.path.join(output_dir, "phoneme", out_fn.replace(".flac", ".txt"))
59
+ save_audio(data[int(item['vad'][0]*samplerate):int(item['vad'][1]*samplerate)], out_audio_fn)
60
+ save_text(item['text'], out_text_fn)
61
+ phns = phonemize_and_save(item['text'], out_phn_fn)
62
+ all_phns.update(phns)
63
+
64
+ return all_phns
65
+
66
+
67
+ from collections import defaultdict
68
+ # Function to create a defaultdict recursively
69
+ def nested_defaultdict(levels, inner_type):
70
+ if levels <= 1:
71
+ return defaultdict(inner_type)
72
+ return defaultdict(lambda: nested_defaultdict(levels-1, inner_type))
73
+
74
+
75
+ def open_mani(fn):
76
+ print("load segmentation and transcription metadata...")
77
+ stime = time.time()
78
+ data = []
79
+ with gzip.open(fn, 'rt', encoding='utf-8') as f:
80
+ for line in f:
81
+ data.append(json.loads(line))
82
+ print(f"loading done, took {time.time() - stime:.4f} seconds")
83
+ return data
84
+
85
+ def cut(split,
86
+ audio_dir,
87
+ mani_dir,
88
+ output_dir,
89
+ n_process=32,
90
+ percent=0.5):
91
+ split2manifest = {
92
+ "train": [
93
+ "libriheavy_long_cuts_small.jsonl",
94
+ "libriheavy_long_cuts_medium.jsonl",
95
+ "libriheavy_long_cuts_large.jsonl",
96
+ "libriheavy_cuts_small.jsonl",
97
+ "libriheavy_cuts_medium.jsonl",
98
+ "libriheavy_cuts_large.jsonl",
99
+ ],
100
+ "valid": [
101
+ "libriheavy_cuts_dev.jsonl",
102
+ "libriheavy_long_cuts_dev.jsonl"
103
+ ],
104
+ "test": [
105
+ "libriheavy_cuts_test_clean.jsonl",
106
+ "libriheavy_cuts_test_other.jsonl",
107
+ "libriheavy_long_cuts_test_clean.jsonl",
108
+ "libriheavy_long_cuts_test_other.jsonl"
109
+ ]
110
+ }
111
+
112
+ print("organize data by recording_id (i.e. the original big .flac file name)...")
113
+ stime = time.time()
114
+ organized_data = nested_defaultdict(4, list)
115
+ manifest_fn = os.path.join(output_dir, "manifest_mimi", split+".txt")
116
+ os.makedirs(os.path.join(output_dir, "manifest_mimi"), exist_ok=True)
117
+ with open(manifest_fn, "w") as wf:
118
+ for mani_fn in split2manifest[split]:
119
+ # data = open_mani(os.path.join(mani_dir, mani_fn))
120
+ data = read_jsonl(os.path.join(mani_dir, mani_fn))
121
+ for item in data:
122
+ file_id = item['supervisions'][0]['id'] + '.flac'
123
+ recording_id = item['recording']['id'] + '.flac'
124
+ sizeSplit, spk, book, flac = recording_id.split("/") # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb'
125
+ if os.path.isfile(os.path.join(audio_dir, recording_id)):
126
+ vad = (item['start'], item['start']+item['duration'])
127
+ text = item['supervisions'][0]['custom']['texts'][0]
128
+ file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac"
129
+ organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad":vad, "text": text})
130
+ wf.writelines(f"{file_id}\t{item['duration']}\n")
131
+
132
+ # #### take only a subset of tasks
133
+ tasks = [(os.path.join(audio_dir, recording_id), output_dir, organized_data[sizeSplit][spk][book][recording_id], spk) for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]]
134
+ ntasks = len(tasks)
135
+ spk2tasks = defaultdict(list)
136
+ for task in tasks:
137
+ spk2tasks[task[3]].append(task)
138
+ # randomly shuffle each task list for each speaker
139
+ for spk in spk2tasks:
140
+ random.shuffle(spk2tasks[spk])
141
+ # take only a `percent` fraction of the tasks, uniformly sampled across speakers
142
+ # randomly pick a speaker, and then randomly pick a task from that speaker
143
+ tasks = []
144
+ while len(tasks) < ntasks * percent:
145
+ spk = random.choice(list(spk2tasks.keys()))
146
+ if len(spk2tasks[spk]) == 0:
147
+ continue
148
+ tasks.append(spk2tasks[spk].pop()[:-1])
149
+ print(f"take only {percent*100:.2f}% of the tasks, {len(tasks)} out of {ntasks} tasks")
150
+ #### take only a subset of tasks
151
+
152
+ print(f"organizing done, took {time.time() - stime:.4f} seconds")
153
+ print(f"Launching {n_process} processes")
154
+ phn_vocab = set()
155
+ cnt = 0
156
+ with multiprocessing.Pool(processes=n_process) as pool:
157
+ for phns in tqdm.tqdm(pool.imap_unordered(cut_sequence, tasks), total=len(tasks)):
158
+ cnt += 1
159
+ if phns != None:
160
+ phn_vocab.update(phns)
161
+
162
+ # save phn vocabulary
163
+ if split == "train":
164
+ vocab_fn = os.path.join(output_dir, "vocab.txt")
165
+ with open(vocab_fn, "w") as f:
166
+ for i, phn in enumerate(list(phn_vocab)):
167
+ if i < len(list(phn_vocab)) - 1:
168
+ f.write(f"{str(i)}\t{phn}\n")
169
+ else:
170
+ f.write(f"{str(i)}\t{phn}")
171
+
172
+
173
+ def parse_args():
174
+ parser = argparse.ArgumentParser(description="Cut a dataset in small "
175
+ "sequences using VAD files")
176
+ parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy{,_long}_cuts_{small,medium,large}.jsonl, valid = libriheavy{,_long}_cuts_dev.jsonl, test = libriheavy{,_long}_cuts_test_{clean,other}.jsonl")
177
+ parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example",
178
+ help="Path to the audio directory")
179
+ parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1")
180
+ parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed",
181
+ help="Path to the output directory")
182
+ parser.add_argument('--n_workers', type=int, default=16,
183
+ help="Number of parallel worker processes")
184
+ parser.add_argument('--percent', type=float, default=0.5, help="take only this percent of the tasks, randomly sampled from each speaker")
185
+
186
+
187
+ return parser.parse_args()
188
+
189
+
190
+ if __name__ == "__main__":
191
+ args = parse_args()
192
+ pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True)
193
+ text_tokenizer = TextTokenizer()
194
+ cut(args.split, args.audio_dir, args.manifest_dir, args.output_dir, args.n_workers, args.percent)
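# Worked example of the segment naming and slicing above (made-up numbers): a cut
# starting at 12.3 s with duration 4.56 s becomes vad = (12.3, 16.86), its file_id gets
# the suffix _12.30_16.86.flac, and the 16 kHz source is sliced sample-wise as below.
samplerate = 16000
vad = (12.3, 12.3 + 4.56)
print(f"_{vad[0]:.2f}_{vad[1]:.2f}.flac")                       # _12.30_16.86.flac
print(int(vad[0] * samplerate), int(vad[1] * samplerate))       # 196800 269760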
data/ll60k_preprocessing/step4_encodec_encode.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ def parse_args():
4
+ parser = argparse.ArgumentParser(description="encode the librilight dataset using codec model")
5
+ parser.add_argument('--dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", help="Path to the directory")
6
+ parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
7
+ parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
8
+ parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
9
+ parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
10
+ parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
11
+ parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
12
+ parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
13
+ parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
14
+ parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number')
15
+ parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number')
16
+ parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
17
+ parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
18
+ return parser.parse_args()
19
+
20
+ if __name__ == "__main__":
21
+ import logging
22
+ formatter = (
23
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
24
+ )
25
+ logging.basicConfig(format=formatter, level=logging.INFO)
26
+
27
+ import os, sys
28
+ import numpy as np
29
+ import torch
30
+ import torchaudio
31
+ import tqdm
32
+ import time
33
+
34
+ args = parse_args()
35
+
36
+ def sort_by_audio_len(lens):
37
+ inds = np.argsort(lens).tolist()
38
+ if len(inds) < 10:
39
+ return inds[::-1]
40
+ logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
41
+ logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
42
+ logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
43
+ logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
44
+ return inds[::-1]
45
+
46
+ def write_array_to_txt_file(array, filename):
47
+ with open(filename, 'w') as f:
48
+ for a in array[:-1]:
49
+ f.write(' '.join(map(str, a))+'\n')
50
+ f.write(' '.join(map(str, array[-1])))
51
+
52
+ class mydataset(torch.utils.data.Dataset):
53
+ def __init__(self, split):
54
+ super().__init__()
55
+ self.split = split
56
+ self.audio_dir = audio_dir
57
+ manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
58
+ cur_sp = int(args.partition.split("/")[0])-1
59
+ total_sp = int(args.partition.split("/")[1])
60
+ with open(manifest_fn, "r") as rf:
61
+ self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]
62
+ self.data = [l for l in self.data if os.path.isfile(os.path.join(self.audio_dir, l[0]))]
63
+ def __len__(self):
64
+ return len(self.data)
65
+ def __getitem__(self, ind):
66
+ try:
67
+ afn = self.data[ind][0]
68
+ fn = os.path.join(self.audio_dir, afn)
69
+ audio, sr = torchaudio.load(fn)
70
+ if sr != args.model_sr:
71
+ audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
72
+ sr = args.model_sr
73
+ assert sr == args.model_sr, sr
74
+ except Exception as e:
75
+ # logging.info(f"{e}")
76
+ return None, None, None
77
+ assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
78
+ return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]
79
+ def collate(self, batch):
80
+ lens, audios, segment_ids = [], [], []
81
+ for item in batch:
82
+ if item[0] != None:
83
+ audios.append(item[0])
84
+ lens.append(item[1])
85
+ segment_ids.append(item[2])
86
+ return audios, lens, segment_ids
87
+
88
+ # roots
89
+ sub_root = args.sub_root
90
+ encodec_manifest_dir = os.path.join(args.dir, sub_root, "manifest_mimi")
91
+ audio_dir = os.path.join(args.dir, sub_root, "audio")
92
+ save_manifest_dir = os.path.join(args.dir, sub_root,"manifest_final_encodec")
93
+ if args.encodec_name == "encodec_6f79c6a8.th":
94
+ save_codes_dir = os.path.join(args.dir, sub_root,"encodec_4cb")
95
+ elif args.encodec_name == "encodec_8cb1024_giga.th":
96
+ save_codes_dir = os.path.join(args.dir, sub_root,"encodec_8cb")
97
+
98
+ os.makedirs(save_manifest_dir, exist_ok=True)
99
+ os.makedirs(save_codes_dir, exist_ok=True)
100
+
101
+ # load the encodec model
102
+ def import_encodec():
103
+ from encodec import get_compression_model
104
+ userdir = os.path.expanduser("~")
105
+ model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
106
+ model = torch.nn.DataParallel(model)
107
+ return model
108
+ model = import_encodec()
109
+
110
+ # setup dataloader
111
+ mega_batch_size = 1024
112
+ batch_size = args.batch_size
113
+
114
+ dataset = mydataset(args.split)
115
+ if len(dataset) == 0:
116
+ logging.info(f"no data found for split {args.split} partition {args.partition}")
117
+ sys.exit(0)
118
+ loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
119
+ split = args.split
120
+
121
+ skip = 0
122
+ logging.info(f"now processing split {split} partition {args.partition}...")
123
+ mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
124
+ # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
125
+ logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
126
+ mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
127
+ logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
128
+ with open(mani_fn, "w") as mani_wf:
129
+ # with open(mani_fn, "a") as mani_wf: # resume from where we failed
130
+ for m, mega_batch in enumerate(tqdm.tqdm(loader)):
131
+
132
+ logging.info(f"====================================")
133
+ logging.info(f"====================================")
134
+ logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
135
+ try:
136
+ # if True:
137
+ lengths = np.array(mega_batch[1])
138
+ if len(lengths) == 0: # the loader might not find any audio because step3 writes the manifest first and may only select a subset of it to cut and save
139
+ continue
140
+ sorted_inds = sort_by_audio_len(lengths)
141
+ for j in range(len(sorted_inds))[::-1]:
142
+ if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are shorter than min_len seconds or longer than len_cap seconds
143
+ skip += 1
144
+ del sorted_inds[j]
145
+
146
+ n_steps = int(np.ceil(len(sorted_inds) / batch_size))
147
+ for n in tqdm.tqdm(range(n_steps), disable=True):
148
+ inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
149
+ while len(inds_used) < batch_size:
150
+ inds_used += sorted_inds[:batch_size-len(inds_used)]
151
+ wav_batch = [mega_batch[0][id] for id in inds_used]
152
+ all_lens = [mega_batch[1][id] for id in inds_used]
153
+ segment_id_batch = [mega_batch[2][id] for id in inds_used]
154
+ padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
155
+ # Extract discrete codes from EnCodec
156
+ with torch.no_grad():
157
+ if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time
158
+ codes = []
159
+ inwav = padded_wav.cuda()
160
+ codes.append(model(inwav[:len(inwav)//2])[0].cpu())
161
+ codes.append(model(inwav[len(inwav)//2:])[0].cpu())
162
+ codes = torch.cat(codes, dim=0)
163
+ else:
164
+ encoded_frames = model(padded_wav.cuda())
165
+ codes = encoded_frames[0].cpu() # [B, n_codebook, T]
166
+
167
+ for i, length in enumerate(all_lens):
168
+ save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
169
+ actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
170
+ cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
171
+ os.makedirs(os.path.dirname(save_fn), exist_ok=True)
172
+ write_array_to_txt_file(cur_code, save_fn)
173
+
174
+ mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
175
+ # if i == 10:
176
+ # raise
177
+ except Exception as e:
178
+ print(f'exception!! at {m+1}')
179
+ print(e)
180
+ continue
181
+
182
+ # break
183
+ logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
184
+ # break
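# Worked example (made-up values) of two conventions used above: --partition "k/N"
# keeps every N-th manifest line starting at index k-1, and the number of codes kept
# per utterance is the sample count divided by the 320x downsampling of this EnCodec.
lines = [f"utt{i}" for i in range(10)]
cur_sp, total_sp = int("2/3".split("/")[0]) - 1, int("2/3".split("/")[1])
print(lines[cur_sp::total_sp])            # ['utt1', 'utt4', 'utt7']
print(round(48000 / 320))                 # 150 codes for 3 seconds of 16 kHz audio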
data/ll60k_preprocessing/step4_encodec_encode_script.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ source ~/miniconda3/etc/profile.d/conda.sh
3
+ conda activate voicestar
4
+
5
+ dir=${dir:-/data/scratch/pyp/datasets/librilight}
6
+ sub_root=${sub_root:-preprocessed}
7
+ encodec_name=${encodec_name:-encodec_6f79c6a8.th} # or encodec_8cb1024_giga.th
8
+ n_workers=${n_workers:-12}
9
+ batch_size=${batch_size:-64}
10
+ audio_sr=16000
11
+ model_sr=16000
12
+ downsample_rate=320
13
+ model_code_sr=50
14
+ len_cap=1000
15
+ min_len=0.5
16
+ partition=${partition:-"1/1"}
17
+ split=${split:-"valid"}
18
+
19
+ python step4_encodec_encode.py --dir $dir --sub_root ${sub_root} --encodec_name ${encodec_name} --n_workers $n_workers --batch_size $batch_size --audio_sr $audio_sr --model_sr $model_sr --downsample_rate $downsample_rate --model_code_sr $model_code_sr --len_cap $len_cap --min_len $min_len --partition $partition --split $split
data/ll60k_preprocessing/step5_find_nearest_neighbor.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # for each audio segment, find the non-overlapping neighboring segments
2
+ from importlib.resources import path
3
+ import pathlib
4
+ import soundfile as sf
5
+ import numpy as np
6
+ import json
7
+ import multiprocessing
8
+ import argparse
9
+ import tqdm
10
+ import gzip
11
+ import time
12
+ import os
13
+ from tokenizer import TextTokenizer, tokenize_text
14
+ import glob
15
+ import sys
16
+ import os, random, numpy as np, socket
17
+ import json
18
+ import tqdm
19
+ import json
20
+ import tqdm
21
+ def write_jsonl(data, fn):
22
+ with open(fn, "w") as file:
23
+ for entry in data:
24
+ file.write(json.dumps(entry, ensure_ascii=False) + "\n")
25
+ def read_jsonl(file_path):
26
+ cur_data = []
27
+ with open(file_path, 'r', encoding='utf-8-sig') as file:
28
+ for line in file:
29
+ cur_data.append(json.loads(line.strip()))
30
+ return cur_data
31
+ from collections import defaultdict
32
+ # Function to create a defaultdict recursively
33
+ def nested_defaultdict(levels, inner_type):
34
+ if levels <= 1:
35
+ return defaultdict(inner_type)
36
+ return defaultdict(lambda: nested_defaultdict(levels-1, inner_type))
37
+
38
+ def find_neighbor(args):
39
+ split2manifest = {
40
+ "train": [
41
+ "libriheavy_cuts_small.jsonl",
42
+ "libriheavy_cuts_medium.jsonl",
43
+ "libriheavy_cuts_large.jsonl",
44
+ "libriheavy_long_cuts_small.jsonl",
45
+ "libriheavy_long_cuts_medium.jsonl",
46
+ "libriheavy_long_cuts_large.jsonl"
47
+ ],
48
+ "valid": [
49
+ "libriheavy_cuts_dev.jsonl",
50
+ "libriheavy_long_cuts_dev.jsonl"
51
+ ],
52
+ "test": [
53
+ "libriheavy_cuts_test_clean.jsonl",
54
+ "libriheavy_cuts_test_other.jsonl",
55
+ "libriheavy_long_cuts_test_clean.jsonl",
56
+ "libriheavy_long_cuts_test_other.jsonl"
57
+ ]
58
+ }
59
+
60
+ stime = time.time()
61
+ organized_data = nested_defaultdict(4, list)
62
+ for mani_fn in split2manifest[args.split]:
63
+ # data = open_mani(os.path.join(mani_dir, mani_fn))
64
+ mani_full_fn = os.path.join(args.manifest_dir, mani_fn)
65
+ data = read_jsonl(mani_full_fn)
66
+ for item in data:
67
+ file_id = item['supervisions'][0]['id'] + '.flac'
68
+ recording_id = item['recording']['id'] + '.flac'
69
+ sizeSplit, spk, book, flac = recording_id.split("/") # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb'
70
+ if os.path.isfile(os.path.join(args.audio_dir, recording_id)):
71
+ vad = (item['start'], item['start']+item['duration'])
72
+ text = item['supervisions'][0]['custom']['texts'][0]
73
+ file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac"
74
+ organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad":vad, "text": text})
75
+
76
+ # # for each recording_id, find the non-overlapping neighboring segments based on vad
77
+ # for sizeSplit in organized_data:
78
+ # for spk in organized_data[sizeSplit]:
79
+ # for book in organized_data[sizeSplit][spk]:
80
+ # for recording_id in organized_data[sizeSplit][spk][book]:
81
+ # segments = organized_data[sizeSplit][spk][book][recording_id]
82
+ # segments.sort(key=lambda x: x['vad'][0])
83
+ # for i in range(len(segments)):
84
+ # # for segment i, find the non-overlapping neighboring segments
85
+ # write_fn = os.path.join(args.output_dir, f"{segments[i]['file_id'].replace('.flac', '.txt')}")
86
+ # neighbors = []
87
+ # distance = []
88
+ # for j in range(len(segments)):
89
+ # if segments[i]['vad'][1] < segments[j]['vad'][0] or segments[i]['vad'][0] > segments[j]['vad'][0]:
90
+ # neighbors.append(segments[j]['file_id'].replace('.flac', '.txt'))
91
+ # distance.append(min(abs(segments[i]['vad'][1] - segments[j]['vad'][0]), abs(segments[i]['vad'][0] - segments[j]['vad'][1])))
92
+ # # order neighbors by distance
93
+ # neighbors_distance = [[x, dist] for dist, x in sorted(zip(distance, neighbors))]
94
+ # os.makedirs(os.path.dirname(write_fn), exist_ok=True)
95
+ # with open(write_fn, "w") as f:
96
+ # # note that there might be no neighbors, in which case the file is empty
97
+ # for neighbor, dist in neighbors_distance:
98
+ # f.write(f"{neighbor}\t{dist}\n")
99
+
100
+ # use multiprocessing.Pool for the above
101
+ segments = [organized_data[sizeSplit][spk][book][recording_id] for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]]
102
+ # only keep those that exist
103
+ print(f"originally total {len(segments)} segments")
104
+ segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]),"audio", seg[0]['file_id']))]
105
+ print(f"after check existance, total {len(segments)} segments")
106
+ print(f"organizing took {(time.time()-stime)/60:.2f} minutes")
107
+ with multiprocessing.Pool(processes=args.n_workers) as pool:
108
+ for _ in tqdm.tqdm(pool.imap_unordered(find_neighbor_each, segments), total=len(segments)):
109
+ pass
110
+
111
+ # audio_root = "/data/scratch/pyp/datasets/librilight/preprocessed/audio"
112
+ def find_neighbor_each(segments):
113
+ # for each recording_id, find the non-overlapping neighboring segments based on vad
114
+ # only keep segments that have audio files
115
+ # actually only keep segments that have ipa_alignment files
116
+ segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]),"ipa_alignment", seg['file_id'].replace(".flac", ".txt")))]
117
+ if len(segments) <= 1:
118
+ return
119
+ for i in range(len(segments)):
120
+ # for segment i, find the non-overlapping neighboring segments
121
+ write_fn = os.path.join(args.output_dir, f"{segments[i]['file_id'].replace('.flac', '.txt')}")
122
+ neighbors = []
123
+ distance = []
124
+ for j in range(len(segments)):
125
+ if segments[i]['vad'][1] < segments[j]['vad'][0] or segments[i]['vad'][0] > segments[j]['vad'][0]:
126
+ neighbors.append(segments[j])
127
+ distance.append(min(abs(segments[i]['vad'][1] - segments[j]['vad'][0]), abs(segments[i]['vad'][0] - segments[j]['vad'][1])))
128
+ if len(neighbors) == 0:
129
+ continue
130
+ # order neighbors by distance
131
+ index = np.argsort(distance)
132
+ neighbors_distance = [[neighbors[ind], distance[ind]] for ind in index]
133
+ os.makedirs(os.path.dirname(write_fn), exist_ok=True)
134
+ with open(write_fn, "w") as f:
135
+ # note that there might be no neighbors, in which case the file is empty
136
+ for neighbor, dist in neighbors_distance:
137
+ f.write(f"{neighbor['file_id'].replace('.flac', '.txt')}\t{dist}\t{neighbor['vad'][1] - neighbor['vad'][0]}\n") # file_id, distance, duration
138
+
139
+
140
+
141
+ def parse_args():
142
+ parser = argparse.ArgumentParser(description="Cut a dataset in small "
143
+ "sequences using VAD files")
144
+ parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz")
145
+ parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example",
146
+ help="Path to the audio directory")
147
+ parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1")
148
+ parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed/neighbors",
149
+ help="Path to the output directory")
150
+ parser.add_argument('--n_workers', type=int, default=16,
151
+ help="Number of parallel worker processes")
152
+ return parser.parse_args()
153
+
154
+ if __name__ == "__main__":
155
+ args = parse_args()
156
+ pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True)
157
+ find_neighbor(args)
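A toy, self-contained walk-through of the neighbor selection and distance ordering implemented in `find_neighbor_each` above; the VAD intervals below are made up, and the real script additionally restricts segments to those with existing `ipa_alignment` files. Each output row it writes is `neighbor_file<TAB>distance<TAB>duration`, ordered by increasing distance.

```python
import numpy as np

# Hypothetical VAD intervals (start, end) in seconds for one recording.
segments = [
    {"file_id": "a.flac", "vad": (0.0, 3.0)},
    {"file_id": "b.flac", "vad": (4.0, 6.5)},
    {"file_id": "c.flac", "vad": (7.0, 9.0)},
]

i = 1  # find the neighbors of "b.flac"
neighbors, distance = [], []
for j in range(len(segments)):
    # same test as in find_neighbor_each: j starts after i ends, or j starts before i starts
    if segments[i]["vad"][1] < segments[j]["vad"][0] or segments[i]["vad"][0] > segments[j]["vad"][0]:
        neighbors.append(segments[j])
        distance.append(min(abs(segments[i]["vad"][1] - segments[j]["vad"][0]),
                            abs(segments[i]["vad"][0] - segments[j]["vad"][1])))

order = np.argsort(distance)
print([(neighbors[k]["file_id"], distance[k]) for k in order])
# -> [('c.flac', 0.5), ('a.flac', 1.0)]
```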
data/ll60k_preprocessing/step6_forced_alignment.py ADDED
@@ -0,0 +1,86 @@
1
+ import os, sys
2
+ import subprocess, tqdm
3
+ from concurrent.futures import ThreadPoolExecutor
4
+
5
+ def align_folders(audio_root, subfolder, subsubfolder):
6
+ # Construct output folder path
7
+ file_root = os.path.dirname(audio_root)
8
+ out_folder = f"{file_root}/alignment/{subfolder}/{subsubfolder}"
9
+
10
+ # Create the output directory
11
+ os.makedirs(out_folder, exist_ok=True)
12
+
13
+ # Construct the MFA align command
14
+ command = [
15
+ "mfa", "align", "--single_speaker", "-j", "8", "--clean",
16
+ f"{audio_root}/{subfolder}/{subsubfolder}", "english_us_arpa", "english_us_arpa",
17
+ out_folder, "--beam", "50", "--retry_beam", "400", "--output_format", "csv"
18
+ ]
19
+
20
+ # Run the command
21
+ subprocess.run(command, check=True)
22
+
23
+ def main(file_root = "/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", max_parallel_jobs=10, max_spk=100, partition="1/10", n_workers=64):
24
+ # Find all subfolder/subsubfolder combinations
25
+ tasks = []
26
+ audio_root = os.path.join(file_root, "audio")
27
+ for subfolder in os.listdir(audio_root):
28
+ subfolder_path = os.path.join(audio_root, subfolder)
29
+ if os.path.isdir(subfolder_path):
30
+ for subsubfolder in os.listdir(subfolder_path):
31
+ subsubfolder_path = os.path.join(subfolder_path, subsubfolder)
32
+ if os.path.isdir(subsubfolder_path):
33
+ tasks.append((audio_root, subfolder, subsubfolder))
34
+ speaker_folder_map = {}
35
+ for audio_root, subfolder, subsubfolder in tasks:
36
+ if os.path.join(audio_root, subfolder) not in speaker_folder_map:
37
+ speaker_folder_map[os.path.join(audio_root, subfolder)] = [os.path.join(audio_root, subfolder, subsubfolder)]
38
+ else:
39
+ speaker_folder_map[os.path.join(audio_root, subfolder)].append(os.path.join(audio_root, subfolder, subsubfolder))
40
+ speaker_folder_partitions = []
41
+ for audio_root_subfolder, speaker_folders in speaker_folder_map.items():
42
+ speaker_folder_partitions.extend([speaker_folders[i:i+max_spk] for i in range(0, len(speaker_folders), max_spk)])
43
+ s, e = partition.split("/")
44
+ s, e = int(s)-1, int(e)
45
+ cur_tasks = speaker_folder_partitions[s::e]
46
+ import secrets, string
47
+ import soundfile, glob
48
+ from joblib import Parallel, delayed
49
+ def delete_corrupted(fn):
50
+ try:
51
+ x = soundfile.read(fn)
52
+ except:
53
+ print(f"removing corrupted file: {fn}")
54
+ os.remove(fn)
55
+
56
+ for j, task in enumerate(tqdm.tqdm(cur_tasks)):
57
+ # get subfolder for the current task
58
+ subs = [item.split("/")[-2] for item in task]
59
+ # assert that all subs are the same
60
+ assert len(set(subs)) == 1, subs
61
+ sub = subs[0]
62
+ # randomly generate a foldername
63
+ # generate a random character
64
+ # make softlink from item in task to temp folder
65
+ random_string = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(10))
66
+ temp_folder = os.path.join(file_root, "softlink_audio", random_string)
67
+ os.makedirs(temp_folder, exist_ok=True)
68
+ out_folder = f"{file_root}/alignment/{sub}"
69
+ all_out_speaker_folders = [os.path.join(out_folder, os.path.basename(item)) for item in task]
70
+ if sum(os.path.isdir(curpath) for curpath in all_out_speaker_folders) == len(all_out_speaker_folders):
71
+ continue
72
+ # remove audio files that are corrupted
73
+ all_audio_files = [audiofile for item in task for audiofile in glob.glob(item+"/*/*.flac")]
74
+ Parallel(n_jobs=n_workers)(delayed(delete_corrupted)(audiofn) for audiofn in all_audio_files)
75
+ for item in task:
76
+ # make softlink from subsubfolder to a new folder in temp folder
77
+ os.symlink(item, os.path.join(temp_folder, os.path.basename(item)))
78
+ # run mfa on the linked folder, but save alignment to the correct folder
79
+ command = f"mfa align -j {n_workers} {temp_folder} english_us_arpa english_us_arpa {out_folder} --beam 50 --retry_beam 200 --output_format csv --quiet --use_mp --temporary_directory {temp_folder}_temp"
80
+ os.system(command)
81
+ # delete the temp_folder
82
+ os.system(f"rm -r {temp_folder}")
83
+
84
+ if __name__ == "__main__":
85
+ import fire
86
+ fire.Fire(main)
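The `partition` argument is a "k/N" string: the chunked list of speaker folders is striped across N jobs, and job k takes every N-th chunk starting at index k-1, exactly as `speaker_folder_partitions[s::e]` does above. A tiny example with a hypothetical chunk list:

```python
# Hypothetical chunks; in the script each chunk holds up to max_spk speaker folders.
speaker_folder_partitions = [f"chunk_{i}" for i in range(10)]

partition = "2/3"  # this job is number 2 of 3
s, e = partition.split("/")
s, e = int(s) - 1, int(e)
cur_tasks = speaker_folder_partitions[s::e]
print(cur_tasks)  # ['chunk_1', 'chunk_4', 'chunk_7']
```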
data/ll60k_preprocessing/step6_forced_alignment.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+ source ~/miniconda3/etc/profile.d/conda.sh
3
+ conda activate voicecraft
4
+
5
+ partition=$1
6
+ file_root=$2
7
+ max_spk=${max_spk:-100}
8
+ n_workers=${n_workers:-64}
9
+ python step6_forced_alignment.py \
10
+ --partition $partition \
11
+ --file_root $file_root \
12
+ --max_spk $max_spk \
13
+ --n_workers $n_workers
data/ll60k_preprocessing/step7_ipa_alignment.py ADDED
@@ -0,0 +1,114 @@
1
+ # we have raw transcript at
2
+ # /data/scratch/pyp/datasets/librilight/preprocessed/audio
3
+ # we have word and ARPA alignment at
4
+ # /data/scratch/pyp/datasets/librilight/preprocessed/alignment
5
+
6
+ # we have manifest at /data/scratch/pyp/datasets/librilight/preprocessed/manifest_mimi
7
+ # where each row is like large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5_610.32_630.08.flac 19.76
8
+
9
+ # we want to create IPA alignment from the raw transcript and word alignment, using phonemizer
10
+ # save at /data/scratch/pyp/datasets/librilight/preprocessed/ipa_alignment
11
+
12
+ # since ipa phonemized results are not 1-to-1 with words (10 words might lead to an ipa sequence of 7 phonemes), we have to run the phonemizer on each segment of the word sequence
13
+ import os, string, csv, random, tqdm, glob
14
+ from tokenizer import TextTokenizer, tokenize_text
15
+
16
+
17
+ def remove_punctuation(input_string):
18
+ translator = str.maketrans('', '', string.punctuation)
19
+ return input_string.translate(translator)
20
+
21
+
22
+
23
+ def create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, save=False, prompt_dur=30):
24
+ os.makedirs(os.path.dirname(ipa_alignment_fn), exist_ok=True)
25
+ trans_fn = os.path.join(trans_dir, fn.replace(audio_ext, trans_ext))
26
+ if not os.path.isfile(trans_fn):
27
+ return [], True
28
+ align_fn = os.path.join(align_dir, fn.replace(audio_ext, arpa_ext))
29
+ if not os.path.isfile(align_fn):
30
+ return [], True
31
+ # get raw transcript
32
+ with open(trans_fn, 'r') as f:
33
+ transcript = f.read().strip()
34
+ raw_word_list = transcript.split(" ")
35
+ # get word alignment
36
+ with open(align_fn, 'r') as f:
37
+ word_alignment = csv.reader(f)
38
+ word_alignment = [row for row in word_alignment if row[3]=='words']
39
+
40
+ ipa_alignment = []
41
+
42
+ for j, (item, raw_word) in enumerate(zip(word_alignment, raw_word_list)):
43
+ start, end, word = float(item[0]), float(item[1]), item[2]
44
+ if end > prompt_dur:
45
+ break
46
+ punc_re_raw_word = remove_punctuation(raw_word)
47
+ if not remove_punctuation(word).lower() == punc_re_raw_word.lower():
48
+ # print(f"word from alignment csv: {word}, word from txt: {raw_word}")
49
+ return ipa_alignment, True
50
+ if random.random() < use_prob:
51
+ cur_words = " ".join(raw_word_list[:j+1])
52
+ phn = tokenize_text(text_tokenizer, cur_words)
53
+ if len(phn) == 0:
54
+ continue
55
+ phn = " ".join(phn)
56
+ start = 0 # at this point, we always start from the beginning of the sentence
57
+ ipa_alignment.append([start, end, phn])
58
+ if save:
59
+ if ipa_alignment:
60
+ with open(ipa_alignment_fn, 'w') as f:
61
+ for item in ipa_alignment:
62
+ f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n")
63
+ else:
64
+ return ipa_alignment, False
65
+
66
+
67
+
68
+ def main(
69
+ data_root: str = '/data/scratch/pyp/datasets/librilight/preprocessed',
70
+ audio_ext: str = '.flac',
71
+ arpa_ext: str = '.csv',
72
+ trans_ext: str = '.txt',
73
+ split: str = 'valid',
74
+ use_prob: float = 0.5,
75
+ max_dur: float = 30., # do not consider utterance longer than this
76
+ prompt_dur: float = 30., # do not consider prompt longer than this
77
+ ):
78
+ text_tokenizer = TextTokenizer()
79
+ trans_dir = f'{data_root}/audio'
80
+ align_dir = f'{data_root}/alignment'
81
+ manifest_fn = f"{data_root}/manifest_final_encodec/{split}*=*.txt"
82
+ manifest_fns = glob.glob(manifest_fn)
83
+ target_dir = f'{data_root}/ipa_alignment'
84
+ encodec_sr = 50
85
+ os.makedirs(target_dir, exist_ok=True)
86
+ manifest = []
87
+ for manifest_fn in manifest_fns:
88
+ with open(manifest_fn, 'r') as f:
89
+ temp = [l.strip().split("\t") for l in f.readlines()]
90
+ manifest += [l[0] + audio_ext for l in temp if float(l[1])/encodec_sr < max_dur]
91
+ # # sequential processing
92
+ n_flags = 0
93
+ zero_words = 0
94
+ for j, fn in enumerate(tqdm.tqdm(manifest)):
95
+ ipa_alignment_fn = os.path.join(target_dir, fn.replace(audio_ext, '.txt'))
96
+ ipa_alignment, flag = create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, prompt_dur=prompt_dur)
97
+ n_flags += flag
98
+ if not ipa_alignment:
99
+ zero_words += 1
100
+ # print(f"{n_flags} out of {j+1} utterances have mismatched words")
101
+ # print(f"{zero_words} out of {j+1} utterances have zero words")
102
+ if ipa_alignment:
103
+ with open(ipa_alignment_fn, 'w') as f:
104
+ for item in ipa_alignment:
105
+ f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n")
106
+
107
+ # # # # do the above using joblib parallisim
108
+ # print(f"Processing {len(manifest)} utterances")
109
+ # from joblib import Parallel, delayed
110
+ # Parallel(n_jobs=32, verbose=2)(delayed(create_alignment)(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, os.path.join(target_dir, fn.replace(audio_ext, '.txt')), save=True) for fn in manifest)
111
+
112
+ if __name__ == "__main__":
113
+ import fire
114
+ fire.Fire(main)
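Each IPA alignment file written above is a small TSV whose rows are `start<TAB>end<TAB>phonemes`, where `start` is always 0, `end` is the end time of the corresponding word in the MFA alignment, and the phoneme string is space-joined. A minimal reader sketch for downstream use (the function name is illustrative, not from the repo):

```python
from typing import List, Tuple

def read_ipa_alignment(path: str) -> List[Tuple[float, float, List[str]]]:
    # Parse rows of "start<TAB>end<TAB>phonemes" as written by create_alignment above.
    rows = []
    with open(path) as f:
        for line in f:
            start, end, phn = line.rstrip("\n").split("\t")
            rows.append((float(start), float(end), phn.split(" ")))
    return rows
```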
data/ll60k_preprocessing/tokenizer.py ADDED
@@ -0,0 +1,460 @@
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ # from encodec import EncodecModel
24
+ # from encodec.utils import convert_audio
25
+ # from lhotse.features import FeatureExtractor
26
+ # from lhotse.utils import Seconds, compute_num_frames
27
+ from phonemizer.backend import EspeakBackend
28
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
29
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
30
+ from phonemizer.punctuation import Punctuation
31
+ from phonemizer.separator import Separator
32
+
33
+ try:
34
+ from pypinyin import Style, pinyin
35
+ from pypinyin.style._utils import get_finals, get_initials
36
+ except Exception:
37
+ pass
38
+
39
+
40
+ class PypinyinBackend:
41
+ """PypinyinBackend for Chinese. Most codes is referenced from espnet.
42
+ There are two types pinyin or initials_finals, one is
43
+ just like "ni1 hao3", the other is like "n i1 h ao3".
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ backend="initials_finals",
49
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
50
+ ) -> None:
51
+ self.backend = backend
52
+ self.punctuation_marks = punctuation_marks
53
+
54
+ def phonemize(
55
+ self, text: List[str], separator: Separator, strip=True, njobs=1
56
+ ) -> List[str]:
57
+ assert isinstance(text, List)
58
+ phonemized = []
59
+ for _text in text:
60
+ _text = re.sub(" +", " ", _text.strip())
61
+ _text = _text.replace(" ", separator.word)
62
+ phones = []
63
+ if self.backend == "pypinyin":
64
+ for n, py in enumerate(
65
+ pinyin(
66
+ _text, style=Style.TONE3, neutral_tone_with_five=True
67
+ )
68
+ ):
69
+ if all([c in self.punctuation_marks for c in py[0]]):
70
+ if len(phones):
71
+ assert phones[-1] == separator.syllable
72
+ phones.pop(-1)
73
+
74
+ phones.extend(list(py[0]))
75
+ else:
76
+ phones.extend([py[0], separator.syllable])
77
+ elif self.backend == "pypinyin_initials_finals":
78
+ for n, py in enumerate(
79
+ pinyin(
80
+ _text, style=Style.TONE3, neutral_tone_with_five=True
81
+ )
82
+ ):
83
+ if all([c in self.punctuation_marks for c in py[0]]):
84
+ if len(phones):
85
+ assert phones[-1] == separator.syllable
86
+ phones.pop(-1)
87
+ phones.extend(list(py[0]))
88
+ else:
89
+ if py[0][-1].isalnum():
90
+ initial = get_initials(py[0], strict=False)
91
+ if py[0][-1].isdigit():
92
+ final = (
93
+ get_finals(py[0][:-1], strict=False)
94
+ + py[0][-1]
95
+ )
96
+ else:
97
+ final = get_finals(py[0], strict=False)
98
+ phones.extend(
99
+ [
100
+ initial,
101
+ separator.phone,
102
+ final,
103
+ separator.syllable,
104
+ ]
105
+ )
106
+ else:
107
+ assert ValueError
108
+ else:
109
+ raise NotImplementedError
110
+ phonemized.append(
111
+ "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
112
+ )
113
+ return phonemized
114
+
115
+
116
+ class TextTokenizer:
117
+ """Phonemize Text."""
118
+
119
+ def __init__(
120
+ self,
121
+ language="en-us",
122
+ backend="espeak",
123
+ separator=Separator(word="_", syllable="-", phone="|"),
124
+ preserve_punctuation=True,
125
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
126
+ with_stress: bool = False,
127
+ tie: Union[bool, str] = False,
128
+ language_switch: LanguageSwitch = "keep-flags",
129
+ words_mismatch: WordMismatch = "ignore",
130
+ ) -> None:
131
+ if backend == "espeak":
132
+ phonemizer = EspeakBackend(
133
+ language,
134
+ punctuation_marks=punctuation_marks,
135
+ preserve_punctuation=preserve_punctuation,
136
+ with_stress=with_stress,
137
+ tie=tie,
138
+ language_switch=language_switch,
139
+ words_mismatch=words_mismatch,
140
+ )
141
+ elif backend in ["pypinyin", "pypinyin_initials_finals"]:
142
+ phonemizer = PypinyinBackend(
143
+ backend=backend,
144
+ punctuation_marks=punctuation_marks + separator.word,
145
+ )
146
+ else:
147
+ raise NotImplementedError(f"{backend}")
148
+
149
+ self.backend = phonemizer
150
+ self.separator = separator
151
+
152
+ def to_list(self, phonemized: str) -> List[str]:
153
+ fields = []
154
+ for word in phonemized.split(self.separator.word):
155
+ # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
156
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
157
+ fields.extend(
158
+ [p for p in pp if p != self.separator.phone]
159
+ + [self.separator.word]
160
+ )
161
+ assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
162
+ self.separator.phone
163
+ )
164
+ return fields[:-1]
165
+
166
+ def __call__(self, text, strip=True) -> List[List[str]]:
167
+ if isinstance(text, str):
168
+ text = [text]
169
+
170
+ phonemized = self.backend.phonemize(
171
+ text, separator=self.separator, strip=strip, njobs=1
172
+ )
173
+ return [self.to_list(p) for p in phonemized]
174
+
175
+
176
+ def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
177
+ phonemes = tokenizer([text.strip()])
178
+ return phonemes[0] # k2symbols
179
+
180
+
181
+ def remove_encodec_weight_norm(model):
182
+ from encodec.modules import SConv1d
183
+ from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
184
+ from torch.nn.utils import remove_weight_norm
185
+ encoder = model.encoder.model
186
+ for key in encoder._modules:
187
+ if isinstance(encoder._modules[key], SEANetResnetBlock):
188
+ remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
189
+ block_modules = encoder._modules[key].block._modules
190
+ for skey in block_modules:
191
+ if isinstance(block_modules[skey], SConv1d):
192
+ remove_weight_norm(block_modules[skey].conv.conv)
193
+ elif isinstance(encoder._modules[key], SConv1d):
194
+ remove_weight_norm(encoder._modules[key].conv.conv)
195
+ decoder = model.decoder.model
196
+ for key in decoder._modules:
197
+ if isinstance(decoder._modules[key], SEANetResnetBlock):
198
+ remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
199
+ block_modules = decoder._modules[key].block._modules
200
+ for skey in block_modules:
201
+ if isinstance(block_modules[skey], SConv1d):
202
+ remove_weight_norm(block_modules[skey].conv.conv)
203
+ elif isinstance(decoder._modules[key], SConvTranspose1d):
204
+ remove_weight_norm(decoder._modules[key].convtr.convtr)
205
+ elif isinstance(decoder._modules[key], SConv1d):
206
+ remove_weight_norm(decoder._modules[key].conv.conv)
207
+
208
+
209
+ # class AudioTokenizer:
210
+ # """EnCodec audio."""
211
+
212
+ # def __init__(
213
+ # self,
214
+ # bandwidth, float=6.0,
215
+ # device: Any = None,
216
+ # ) -> None:
217
+ # # Instantiate a pretrained EnCodec model
218
+ # model = EncodecModel.encodec_model_24khz()
219
+ # model.set_target_bandwidth(bandwidth=bandwidth)
220
+ # remove_encodec_weight_norm(model)
221
+
222
+ # if not device:
223
+ # device = torch.device("cpu")
224
+ # if torch.cuda.is_available():
225
+ # device = torch.device("cuda:0")
226
+
227
+ # self._device = device
228
+
229
+ # self.codec = model.to(device)
230
+ # self.sample_rate = model.sample_rate
231
+ # self.channels = model.channels
232
+
233
+ # @property
234
+ # def device(self):
235
+ # return self._device
236
+
237
+ # def encode(self, wav: torch.Tensor) -> torch.Tensor:
238
+ # return self.codec.encode(wav.to(self.device))
239
+
240
+ # def decode(self, frames: torch.Tensor) -> torch.Tensor:
241
+ # return self.codec.decode(frames)
242
+
243
+ # class AudioTokenizer:
244
+ # """EnCodec audio."""
245
+
246
+ # def __init__(
247
+ # self,
248
+ # bandwidth: float=6.0,
249
+ # device: Any = None,
250
+ # hificodec=False,
251
+ # signature = None
252
+ # ) -> None:
253
+ # self.hificodec = hificodec
254
+ # self.customized = True if signature != None else False
255
+ # if hificodec:
256
+ # import sys
257
+ # sys.path.append("/home/pyp/AcademiCodec")
258
+ # from academicodec.models.hificodec.vqvae import VQVAE
259
+ # config_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/config_16k_320d.json"
260
+ # model_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/checkpoint/HiFi-Codec-16k-320d"
261
+ # self.sample_rate = 16000
262
+ # self.channels = 1
263
+ # model = VQVAE(config_path, model_path, with_encoder=True)
264
+ # model.generator.remove_weight_norm()
265
+ # model.encoder.remove_weight_norm()
266
+ # model.eval()
267
+ # else:
268
+ # if signature != None:
269
+ # # use customized encodec model
270
+ # # import sys
271
+ # # sys.path.append("home/pyp/audiocraft")
272
+ # from audiocraft.solvers import CompressionSolver
273
+ # model_path = f'//sig/{signature}'
274
+ # model = CompressionSolver.model_from_checkpoint(model_path)
275
+ # self.sample_rate = model.sample_rate
276
+ # self.channels = model.channels
277
+ # else:
278
+ # # Instantiate a pretrained EnCodec model
279
+ # model = EncodecModel.encodec_model_24khz()
280
+ # model.set_target_bandwidth(bandwidth=bandwidth)
281
+ # remove_encodec_weight_norm(model)
282
+ # self.sample_rate = model.sample_rate
283
+ # self.channels = model.channels
284
+
285
+ # if not device:
286
+ # device = torch.device("cpu")
287
+ # if torch.cuda.is_available():
288
+ # device = torch.device("cuda:0")
289
+
290
+ # self._device = device
291
+
292
+ # self.codec = model.to(device)
293
+
294
+ # @property
295
+ # def device(self):
296
+ # return self._device
297
+
298
+ # def encode(self, wav: torch.Tensor) -> torch.Tensor:
299
+ # if self.hificodec:
300
+ # assert wav.ndim==3 and wav.shape[:2] == torch.Size((1,1)), wav.shape
301
+ # wav = wav.squeeze(0)
302
+ # codes = self.codec.encode(wav.to(self.device)) # [1,T,4]
303
+ # return [(codes.transpose(2,1),None)]
304
+ # elif self.customized:
305
+ # codes = self.codec.encode(wav.to(self.device))
306
+ # return [(codes[0], None)]
307
+ # return self.codec.encode(wav.to(self.device))
308
+
309
+ # def decode(self, frames: torch.Tensor) -> torch.Tensor:
310
+ # if self.hificodec:
311
+ # frames = frames[0][0] # [1,4,T]
312
+ # assert frames.shape[:2] == torch.Size((1,4))
313
+ # audio = self.codec(frames.transpose(2,1))
314
+ # assert audio.shape[0] == 1, audio.shape
315
+ # return audio
316
+ # elif self.customized:
317
+ # frames = frames[0][0] # [1,4,T]
318
+ # return self.codec.decode(frames)
319
+ # return self.codec.decode(frames)
320
+ # # try:
321
+ # # return self.codec.decode(frames)
322
+ # # except:
323
+ # # import logging
324
+ # # logging.info(f"error when decoding frame of shape: {frames[0][0].shape}")
325
+ # # self.codec.cpu()
326
+ # # ret = self.codec.cpu().decode([(frames[0][0].cpu(),None)])[0].to(self._device)
327
+ # # self.codec.to(self._device)
328
+ # # return [ret]
329
+
330
+ # def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
331
+ # # Load and pre-process the audio waveform
332
+ # if offset != -1 and num_frames!=-1:
333
+ # wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
334
+ # else:
335
+ # wav, sr = torchaudio.load(audio_path)
336
+ # wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
337
+ # wav = wav.unsqueeze(0)
338
+
339
+ # # Extract discrete codes from EnCodec
340
+ # with torch.no_grad():
341
+ # encoded_frames = tokenizer.encode(wav)
342
+ # return encoded_frames
343
+
344
+
345
+ # @dataclass
346
+ # class AudioTokenConfig:
347
+ # frame_shift: Seconds = 320.0 / 24000
348
+ # num_quantizers: int = 8
349
+
350
+ # def to_dict(self) -> Dict[str, Any]:
351
+ # return asdict(self)
352
+
353
+ # @staticmethod
354
+ # def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
355
+ # return AudioTokenConfig(**data)
356
+
357
+
358
+ # class AudioTokenExtractor(FeatureExtractor):
359
+ # name = "encodec"
360
+ # config_type = AudioTokenConfig
361
+
362
+ # def __init__(self, config: Optional[Any] = None):
363
+ # super(AudioTokenExtractor, self).__init__(config)
364
+ # self.tokenizer = AudioTokenizer()
365
+
366
+ # def extract(
367
+ # self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
368
+ # ) -> np.ndarray:
369
+ # if not isinstance(samples, torch.Tensor):
370
+ # samples = torch.from_numpy(samples)
371
+ # if sampling_rate != self.tokenizer.sample_rate:
372
+ # samples = convert_audio(
373
+ # samples,
374
+ # sampling_rate,
375
+ # self.tokenizer.sample_rate,
376
+ # self.tokenizer.channels,
377
+ # )
378
+ # if len(samples.shape) == 2:
379
+ # samples = samples.unsqueeze(0)
380
+ # else:
381
+ # raise ValueError()
382
+
383
+ # device = self.tokenizer.device
384
+ # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
385
+ # codes = encoded_frames[0][0] # [B, n_q, T]
386
+ # if True:
387
+ # duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
388
+ # expected_num_frames = compute_num_frames(
389
+ # duration=duration,
390
+ # frame_shift=self.frame_shift,
391
+ # sampling_rate=sampling_rate,
392
+ # )
393
+ # assert abs(codes.shape[-1] - expected_num_frames) <= 1
394
+ # codes = codes[..., :expected_num_frames]
395
+ # return codes.cpu().squeeze(0).permute(1, 0).numpy()
396
+
397
+ # @property
398
+ # def frame_shift(self) -> Seconds:
399
+ # return self.config.frame_shift
400
+
401
+ # def feature_dim(self, sampling_rate: int) -> int:
402
+ # return self.config.num_quantizers
403
+
404
+ # def pad_tensor_list(self, tensor_list, device, padding_value=0):
405
+ # # 计算每个张量的长度
406
+ # lengths = [tensor.shape[0] for tensor in tensor_list]
407
+ # # 使用pad_sequence函数进行填充
408
+ # tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
409
+ # padded_tensor = torch.nn.utils.rnn.pad_sequence(
410
+ # tensor_list, batch_first=True, padding_value=padding_value
411
+ # )
412
+ # return padded_tensor, lengths
413
+
414
+ # def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
415
+ # samples = [wav.squeeze() for wav in samples]
416
+ # device = self.tokenizer.device
417
+ # samples, lengths = self.pad_tensor_list(samples, device)
418
+ # samples = samples.unsqueeze(1)
419
+
420
+ # if not isinstance(samples, torch.Tensor):
421
+ # samples = torch.from_numpy(samples)
422
+ # if len(samples.shape) != 3:
423
+ # raise ValueError()
424
+ # if sampling_rate != self.tokenizer.sample_rate:
425
+ # samples = [
426
+ # convert_audio(
427
+ # wav,
428
+ # sampling_rate,
429
+ # self.tokenizer.sample_rate,
430
+ # self.tokenizer.channels,
431
+ # )
432
+ # for wav in samples
433
+ # ]
434
+ # # Extract discrete codes from EnCodec
435
+ # with torch.no_grad():
436
+ # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
437
+ # encoded_frames = encoded_frames[0][0] # [B, n_q, T]
438
+ # batch_codes = []
439
+ # for b, length in enumerate(lengths):
440
+ # codes = encoded_frames[b]
441
+ # duration = round(length / sampling_rate, ndigits=12)
442
+ # expected_num_frames = compute_num_frames(
443
+ # duration=duration,
444
+ # frame_shift=self.frame_shift,
445
+ # sampling_rate=sampling_rate,
446
+ # )
447
+ # batch_codes.append(codes[..., :expected_num_frames])
448
+ # return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
449
+
450
+
451
+ if __name__ == "__main__":
452
+ model = EncodecModel.encodec_model_24khz()
453
+ model.set_target_bandwidth(6.0)
454
+ # model.cuda()
455
+ samples = torch.from_numpy(np.random.random([4, 1, 30000])).type(torch.float32)
456
+ codes_norm = model.encode(samples.cuda())
457
+ remove_encodec_weight_norm(model)
458
+ codes_raw = model.encode(samples.cuda())
459
+
460
+ assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
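A short usage sketch for the `TextTokenizer` defined in this file, imported the way the step5/step7 scripts in this directory import it; it assumes espeak-ng is installed for the phonemizer backend. With the default separators (word "_", phone "|"), `tokenize_text` returns a flat list of IPA symbols with "_" marking word boundaries (the exact symbols depend on the installed espeak version):

```python
from tokenizer import TextTokenizer, tokenize_text

text_tokenizer = TextTokenizer()  # espeak backend, en-us, default separators
phns = tokenize_text(text_tokenizer, "VoiceStar is duration controllable")
print(phns)  # e.g. ['v', 'ɔɪ', 's', ..., '_', 'ɪ', 'z', '_', ...] (backend-dependent)
```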
data/tokenizer.py ADDED
@@ -0,0 +1,295 @@
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ # from encodec import EncodecModel
24
+ # from encodec.utils import convert_audio
25
+ # from lhotse.features import FeatureExtractor
26
+ # from lhotse.utils import Seconds, compute_num_frames
27
+ from phonemizer.backend import EspeakBackend
28
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
29
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
30
+ from phonemizer.punctuation import Punctuation
31
+ from phonemizer.separator import Separator
32
+
33
+
34
+ try:
35
+ from pypinyin import Style, pinyin
36
+ from pypinyin.style._utils import get_finals, get_initials
37
+ except Exception:
38
+ pass
39
+
40
+
41
+ class PypinyinBackend:
42
+ """PypinyinBackend for Chinese. Most codes is referenced from espnet.
43
+ There are two types pinyin or initials_finals, one is
44
+ just like "ni1 hao3", the other is like "n i1 h ao3".
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ backend="initials_finals",
50
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
51
+ ) -> None:
52
+ self.backend = backend
53
+ self.punctuation_marks = punctuation_marks
54
+
55
+ def phonemize(
56
+ self, text: List[str], separator: Separator, strip=True, njobs=1
57
+ ) -> List[str]:
58
+ assert isinstance(text, List)
59
+ phonemized = []
60
+ for _text in text:
61
+ _text = re.sub(" +", " ", _text.strip())
62
+ _text = _text.replace(" ", separator.word)
63
+ phones = []
64
+ if self.backend == "pypinyin":
65
+ for n, py in enumerate(
66
+ pinyin(
67
+ _text, style=Style.TONE3, neutral_tone_with_five=True
68
+ )
69
+ ):
70
+ if all([c in self.punctuation_marks for c in py[0]]):
71
+ if len(phones):
72
+ assert phones[-1] == separator.syllable
73
+ phones.pop(-1)
74
+
75
+ phones.extend(list(py[0]))
76
+ else:
77
+ phones.extend([py[0], separator.syllable])
78
+ elif self.backend == "pypinyin_initials_finals":
79
+ for n, py in enumerate(
80
+ pinyin(
81
+ _text, style=Style.TONE3, neutral_tone_with_five=True
82
+ )
83
+ ):
84
+ if all([c in self.punctuation_marks for c in py[0]]):
85
+ if len(phones):
86
+ assert phones[-1] == separator.syllable
87
+ phones.pop(-1)
88
+ phones.extend(list(py[0]))
89
+ else:
90
+ if py[0][-1].isalnum():
91
+ initial = get_initials(py[0], strict=False)
92
+ if py[0][-1].isdigit():
93
+ final = (
94
+ get_finals(py[0][:-1], strict=False)
95
+ + py[0][-1]
96
+ )
97
+ else:
98
+ final = get_finals(py[0], strict=False)
99
+ phones.extend(
100
+ [
101
+ initial,
102
+ separator.phone,
103
+ final,
104
+ separator.syllable,
105
+ ]
106
+ )
107
+ else:
108
+ assert ValueError
109
+ else:
110
+ raise NotImplementedError
111
+ phonemized.append(
112
+ "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
113
+ )
114
+ return phonemized
115
+
116
+
117
+ class TextTokenizer:
118
+ """Phonemize Text."""
119
+
120
+ def __init__(
121
+ self,
122
+ language="en-us",
123
+ backend="espeak",
124
+ separator=Separator(word="_", syllable="-", phone="|"),
125
+ preserve_punctuation=True,
126
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
127
+ with_stress: bool = False,
128
+ tie: Union[bool, str] = False,
129
+ language_switch: LanguageSwitch = "keep-flags",
130
+ words_mismatch: WordMismatch = "ignore",
131
+ ) -> None:
132
+ if backend == "espeak":
133
+ phonemizer = EspeakBackend(
134
+ language,
135
+ punctuation_marks=punctuation_marks,
136
+ preserve_punctuation=preserve_punctuation,
137
+ with_stress=with_stress,
138
+ tie=tie,
139
+ language_switch=language_switch,
140
+ words_mismatch=words_mismatch,
141
+ )
142
+ elif backend in ["pypinyin", "pypinyin_initials_finals"]:
143
+ phonemizer = PypinyinBackend(
144
+ backend=backend,
145
+ punctuation_marks=punctuation_marks + separator.word,
146
+ )
147
+ else:
148
+ raise NotImplementedError(f"{backend}")
149
+
150
+ self.backend = phonemizer
151
+ self.separator = separator
152
+
153
+ def to_list(self, phonemized: str) -> List[str]:
154
+ fields = []
155
+ for word in phonemized.split(self.separator.word):
156
+ # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
157
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
158
+ fields.extend(
159
+ [p for p in pp if p != self.separator.phone]
160
+ + [self.separator.word]
161
+ )
162
+ assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
163
+ self.separator.phone
164
+ )
165
+ return fields[:-1]
166
+
167
+ def __call__(self, text, strip=True) -> List[List[str]]:
168
+ if isinstance(text, str):
169
+ text = [text]
170
+
171
+ phonemized = self.backend.phonemize(
172
+ text, separator=self.separator, strip=strip, njobs=1
173
+ )
174
+ return [self.to_list(p) for p in phonemized]
175
+
176
+
177
+ def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
178
+ phonemes = tokenizer([text.strip()])
179
+ return phonemes[0] # k2symbols
180
+
181
+
182
+ def remove_encodec_weight_norm(model):
183
+ from encodec.modules import SConv1d
184
+ from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
185
+ from torch.nn.utils import remove_weight_norm
186
+ encoder = model.encoder.model
187
+ for key in encoder._modules:
188
+ if isinstance(encoder._modules[key], SEANetResnetBlock):
189
+ remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
190
+ block_modules = encoder._modules[key].block._modules
191
+ for skey in block_modules:
192
+ if isinstance(block_modules[skey], SConv1d):
193
+ remove_weight_norm(block_modules[skey].conv.conv)
194
+ elif isinstance(encoder._modules[key], SConv1d):
195
+ remove_weight_norm(encoder._modules[key].conv.conv)
196
+ decoder = model.decoder.model
197
+ for key in decoder._modules:
198
+ if isinstance(decoder._modules[key], SEANetResnetBlock):
199
+ remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
200
+ block_modules = decoder._modules[key].block._modules
201
+ for skey in block_modules:
202
+ if isinstance(block_modules[skey], SConv1d):
203
+ remove_weight_norm(block_modules[skey].conv.conv)
204
+ elif isinstance(decoder._modules[key], SConvTranspose1d):
205
+ remove_weight_norm(decoder._modules[key].convtr.convtr)
206
+ elif isinstance(decoder._modules[key], SConv1d):
207
+ remove_weight_norm(decoder._modules[key].conv.conv)
208
+
209
+
210
+ class AudioTokenizer:
211
+ """mimi audio."""
212
+
213
+ def __init__(
214
+ self,
215
+ bandwidth: float=6.0,
216
+ device: Any = None,
217
+ hificodec=False,
218
+ signature = None,
219
+ encode_only = False
220
+ ) -> None:
221
+ self.signature = signature
222
+ from data.encodec import get_compression_model
223
+ model = get_compression_model(signature, encode_only=encode_only, device=device)
224
+ self.sample_rate = model.sample_rate
225
+ self.channels = model.channels
226
+
227
+ if not device:
228
+ device = torch.device("cpu")
229
+ if torch.cuda.is_available():
230
+ device = torch.device("cuda")
231
+
232
+ self._device = device
233
+
234
+ self.codec = model.to(device)
235
+
236
+ @property
237
+ def device(self):
238
+ return self._device
239
+
240
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
241
+ if self.signature != None:
242
+ if self.signature == "lfsc":
243
+ if wav.ndim==3:
244
+ assert wav.shape[:2] == torch.Size((1,1)), wav.shape
245
+ wav = wav.squeeze(0)
246
+ elif wav.ndim==2:
247
+ assert wav.shape[0] == 1, wav.shape
248
+ else:
249
+ raise ValueError(wav.shape)
250
+ audio_len = torch.tensor([wav.shape[1]]).to(self.device)
251
+ codes, encoded_len = self.codec.encode(audio=wav.to(self.device), audio_len=audio_len)
252
+ return codes[:, :, :encoded_len[0]]
253
+ else:
254
+ codes = self.codec.encode(wav.to(self.device))
255
+ return codes[0]
256
+ else:
257
+ assert wav.ndim==3 and wav.shape[:2] == torch.Size((1,1)), wav.shape
258
+ return self.codec.encode(wav.to(self.device))
259
+
260
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
261
+ if self.signature != None and self.signature == "lfsc":
262
+ encoded_len = torch.tensor([frames.shape[-1]]).to(self.device)
263
+ reconstructed_audio, decoded_len = self.codec.decode(tokens=frames, tokens_len=encoded_len)
264
+ return reconstructed_audio[:, :decoded_len[0]].unsqueeze(0)
265
+ else:
266
+ return self.codec.decode(frames)
267
+
268
+ def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
269
+ # Load and pre-process the audio waveform
270
+ if offset != -1 and num_frames!=-1:
271
+ wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
272
+ else:
273
+ wav, sr = torchaudio.load(audio_path)
274
+ if sr != tokenizer.sample_rate:
275
+ wav = torchaudio.transforms.Resample(sr, tokenizer.sample_rate)(wav)
276
+ sr = tokenizer.sample_rate
277
+ if wav.shape[0] == 2:
278
+ wav = wav.mean(dim=0, keepdim=True)
279
+ wav = wav.unsqueeze(0)
280
+ # Extract discrete codes from mimi
281
+ with torch.no_grad():
282
+ encoded_frames = tokenizer.encode(wav)
283
+ return encoded_frames
284
+
285
+
286
+ if __name__ == "__main__":
287
+ # tok = AudioTokenizer(signature="lfsc", device="cpu")
288
+ tok = AudioTokenizer(signature="/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th", device="cpu")
289
+ inaudio = "/home/pyp/BoostedVoiceEditor/demo/pam.wav"
290
+ encoded_frames = tokenize_audio(tok, inaudio)
291
+ print(encoded_frames.shape)
292
+ # decode it back
293
+ decoded_audio = tok.decode(encoded_frames)
294
+ torchaudio.save("/home/pyp/BoostedVoiceEditor/demo/pam_reconstructed_encodec_4cb_2nd.wav", decoded_audio[0], tok.sample_rate)
295
+
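One detail of `tokenize_audio` worth spelling out: `offset` and `num_frames` are passed straight to `torchaudio.load` as `frame_offset`/`num_frames`, so they are measured in samples at the file's native sample rate, not in seconds or codec frames. A hedged sketch of encoding a 3-second window starting at 1.5 s, assuming `./pretrained/encodec_6f79c6a8.th` has already been downloaded as in the inference scripts:

```python
import torchaudio
from data.tokenizer import AudioTokenizer, tokenize_audio

tok = AudioTokenizer(signature="./pretrained/encodec_6f79c6a8.th", device="cpu")

audio_path = "./demo/5895_34622_000026_000002.wav"
start_sec, dur_sec = 1.5, 3.0
info = torchaudio.info(audio_path)

codes = tokenize_audio(
    tok, audio_path,
    offset=int(start_sec * info.sample_rate),     # samples in the source file
    num_frames=int(dur_sec * info.sample_rate),
)
print(codes.shape)  # roughly (1, n_codebooks, dur_sec * 50) at the 50 Hz code rate
```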
demo/5895_34622_000026_000002.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6914b36cccc39da83dc9e2f9370556e62876b6f0e37e0ab324d5c85208107fe4
3
+ size 503738
generated_tts/generated.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b054b4bcd8b52504ef903b578c4040edbfd4949d3aaae6be7f76ac2976a1f280
3
+ size 251598
inference_commandline.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ import random
6
+ import whisper
7
+ import fire
8
+ from argparse import Namespace
9
+
10
+ from data.tokenizer import (
11
+ AudioTokenizer,
12
+ TextTokenizer,
13
+ )
14
+
15
+ from models import voice_star
16
+ from inference_tts_utils import inference_one_sample
17
+
18
+ ############################################################
19
+ # Utility Functions
20
+ ############################################################
21
+
22
+ def seed_everything(seed=1):
23
+ os.environ['PYTHONHASHSEED'] = str(seed)
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.backends.cudnn.benchmark = False
29
+ torch.backends.cudnn.deterministic = True
30
+
31
+
32
+ def estimate_duration(ref_audio_path, text):
33
+ """
34
+ Estimate duration based on seconds per character from the reference audio.
35
+ """
36
+ info = torchaudio.info(ref_audio_path)
37
+ audio_duration = info.num_frames / info.sample_rate
38
+ length_text = max(len(text), 1)
39
+ spc = audio_duration / length_text # seconds per character
40
+ return len(text) * spc
41
+
42
+ ############################################################
43
+ # Main Inference Function
44
+ ############################################################
45
+
46
+ def run_inference(
47
+ reference_speech="./demo/5895_34622_000026_000002.wav",
48
+ target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.",
49
+ # Model
50
+ model_name="VoiceStar_840M_30s", # or VoiceStar_840M_40s, the later model is trained on maximally 40s long speech
51
+ model_root="./pretrained",
52
+ # Additional optional
53
+ reference_text=None, # if None => run whisper on reference_speech
54
+ target_duration=None, # if None => estimate from reference_speech and target_text
55
+ # Default hyperparameters from snippet
56
+ codec_audio_sr=16000, # do not change
57
+ codec_sr=50, # do not change
58
+ top_k=10, # try 10, 20, 30, 40
59
+ top_p=1, # do not change
60
+ min_p=1, # do not change
61
+ temperature=1,
62
+ silence_tokens=None, # do not change it
63
+ kvcache=1, # if OOM, set to 0
64
+ multi_trial=None, # do not change it
65
+ repeat_prompt=1, # increase this to improve speaker similarity, but if the total reference speech duration plus the target duration exceeds the maximal training duration, quality may drop
66
+ stop_repetition=3, # will not use it
67
+ sample_batch_size=1, # do not change
68
+ # Others
69
+ seed=1,
70
+ output_dir="./generated_tts",
71
+ # Some snippet-based defaults
72
+ cut_off_sec=100, # do not adjust this, we always use the entire reference speech. If you do change it, also change reference_text so that it only transcribes the portion of the speech that remains
73
+ ):
74
+ """
75
+ Inference script using Fire.
76
+
77
+ Example:
78
+ python inference_commandline.py \
79
+ --reference_speech "./demo/5895_34622_000026_000002.wav" \
80
+ --target_text "I cannot believe ... this audio is 10 seconds long." \
81
+ --reference_text "(optional) text to use as prefix" \
82
+ --target_duration (optional float)
83
+ """
84
+
85
+ # Seed everything
86
+ seed_everything(seed)
87
+
88
+ # Load model, phn2num, and args
89
+ torch.serialization.add_safe_globals([Namespace])
90
+ device = "cuda" if torch.cuda.is_available() else "cpu"
91
+ ckpt_fn = os.path.join(model_root, model_name+".pth")
92
+ if not os.path.exists(ckpt_fn):
93
+ # use wget to download
94
+ print(f"[Info] Downloading {model_name} checkpoint...")
95
+ os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}")
96
+ bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
97
+ args = bundle["args"]
98
+ phn2num = bundle["phn2num"]
99
+ model = voice_star.VoiceStar(args)
100
+ model.load_state_dict(bundle["model"])
101
+ model.to(device)
102
+ model.eval()
103
+
104
+ # If reference_text not provided, use whisper large-v3-turbo
105
+ if reference_text is None:
106
+ print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
107
+ wh_model = whisper.load_model("large-v3-turbo")
108
+ result = wh_model.transcribe(reference_speech)
109
+ prefix_transcript = result["text"]
110
+ print(f"[Info] Whisper transcribed text: {prefix_transcript}")
111
+ else:
112
+ prefix_transcript = reference_text
113
+
114
+ # If target_duration not provided, estimate from reference speech + target_text
115
+ if target_duration is None:
116
+ target_generation_length = estimate_duration(reference_speech, target_text)
117
+ print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.")
118
+ else:
119
+ target_generation_length = float(target_duration)
120
+
121
+ # signature from snippet
122
+ if args.n_codebooks == 4:
123
+ signature = "./pretrained/encodec_6f79c6a8.th"
124
+ elif args.n_codebooks == 8:
125
+ signature = "./pretrained/encodec_8cb1024_giga.th"
126
+ else:
127
+ # fallback, just use encodec_6f79c6a8
128
+ signature = "./pretrained/encodec_6f79c6a8.th"
129
+
130
+ if silence_tokens is None:
131
+ # default from snippet
132
+ silence_tokens = []
133
+
134
+ if multi_trial is None:
135
+ # default from snippet
136
+ multi_trial = []
137
+
138
+ delay_pattern_increment = args.n_codebooks + 1 # from snippet
139
+
140
+ # We can compute prompt_end_frame if we want, from snippet
141
+ info = torchaudio.info(reference_speech)
142
+ prompt_end_frame = int(cut_off_sec * info.sample_rate)
143
+
144
+ # Prepare tokenizers
145
+ audio_tokenizer = AudioTokenizer(signature=signature)
146
+ text_tokenizer = TextTokenizer(backend="espeak")
147
+
148
+ # decode_config from snippet
149
+ decode_config = {
150
+ 'top_k': top_k,
151
+ 'top_p': top_p,
152
+ 'min_p': min_p,
153
+ 'temperature': temperature,
154
+ 'stop_repetition': stop_repetition,
155
+ 'kvcache': kvcache,
156
+ 'codec_audio_sr': codec_audio_sr,
157
+ 'codec_sr': codec_sr,
158
+ 'silence_tokens': silence_tokens,
159
+ 'sample_batch_size': sample_batch_size
160
+ }
161
+
162
+ # Run inference
163
+ print("[Info] Running TTS inference...")
164
+ concated_audio, gen_audio = inference_one_sample(
165
+ model, args, phn2num, text_tokenizer, audio_tokenizer,
166
+ reference_speech, target_text,
167
+ device, decode_config,
168
+ prompt_end_frame=prompt_end_frame,
169
+ target_generation_length=target_generation_length,
170
+ delay_pattern_increment=delay_pattern_increment,
171
+ prefix_transcript=prefix_transcript,
172
+ multi_trial=multi_trial,
173
+ repeat_prompt=repeat_prompt,
174
+ )
175
+
176
+ # The model returns a list of waveforms, pick the first
177
+ concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
178
+
179
+ # Save the audio (just the generated portion, as the snippet does)
180
+ os.makedirs(output_dir, exist_ok=True)
181
+ out_filename = "generated.wav"
182
+ out_path = os.path.join(output_dir, out_filename)
183
+ torchaudio.save(out_path, gen_audio, codec_audio_sr)
184
+
185
+ print(f"[Success] Generated audio saved to {out_path}")
186
+
187
+
188
+ def main():
189
+ fire.Fire(run_inference)
190
+
191
+ if __name__ == "__main__":
192
+ main()
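Besides the Fire CLI shown in the docstring, `run_inference` can be called directly from Python, which is convenient when batching several target texts. The keyword values below are examples, not recommendations; the paths are the repo's demo assets:

```python
from inference_commandline import run_inference

run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="VoiceStar is duration controllable and can extrapolate.",
    model_name="VoiceStar_840M_30s",
    target_duration=6.0,   # seconds; omit to fall back to the built-in estimate
    top_k=20,
    repeat_prompt=2,       # higher improves speaker similarity, within the duration limits noted above
    output_dir="./generated_tts",
)
# The result is written to ./generated_tts/generated.wav
```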
inference_gradio.py ADDED
@@ -0,0 +1,334 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ gradio_tts_app.py
4
+
5
+ Run:
6
+ python gradio_tts_app.py
7
+
8
+ Then open the printed local or public URL in your browser.
9
+ """
10
+
11
+ import os
12
+ import random
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ import whisper
17
+ import gradio as gr
18
+ from argparse import Namespace
19
+
20
+ # ---------------------------------------------------------------------
21
+ # The following imports assume your local project structure:
22
+ # data/tokenizer.py
23
+ # models/voice_star.py
24
+ # inference_tts_utils.py
25
+ # Adjust if needed.
26
+ # ---------------------------------------------------------------------
27
+ from data.tokenizer import AudioTokenizer, TextTokenizer
28
+ from models import voice_star
29
+ from inference_tts_utils import inference_one_sample
30
+
31
+
32
+ ############################################################
33
+ # Utility Functions
34
+ ############################################################
35
+
36
+ def seed_everything(seed=1):
37
+ os.environ['PYTHONHASHSEED'] = str(seed)
38
+ random.seed(seed)
39
+ np.random.seed(seed)
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed(seed)
42
+ torch.backends.cudnn.benchmark = False
43
+ torch.backends.cudnn.deterministic = True
44
+
45
+
46
+ def estimate_duration(ref_audio_path, text):
47
+ """
48
+ Estimate duration based on seconds per character from the reference audio.
49
+ """
50
+ info = torchaudio.info(ref_audio_path)
51
+ audio_duration = info.num_frames / info.sample_rate
52
+ length_text = max(len(text), 1)
53
+ spc = audio_duration / length_text # seconds per character
54
+ return len(text) * spc
55
+
56
+
57
+ ############################################################
58
+ # Main Inference Function
59
+ ############################################################
60
+
61
+ def run_inference(
62
+ # User-adjustable parameters (no "# do not change" in snippet)
63
+ reference_speech="./demo/5895_34622_000026_000002.wav",
64
+ target_text="VoiceStar is a very interesting model, it's duration controllable and can extrapolate",
65
+ model_name="VoiceStar_840M_40s",
66
+ model_root="./pretrained",
67
+ reference_text=None, # optional
68
+ target_duration=None, # optional
69
+ top_k=10, # can try 10, 20, 30, 40
70
+ temperature=1,
71
+ kvcache=1, # if OOM, set to 0
72
+ repeat_prompt=1, # use higher to improve speaker similarity
73
+ stop_repetition=3, # snippet says "will not use it" but not "do not change"
74
+ seed=1,
75
+ output_dir="./generated_tts",
76
+
77
+ # Non-adjustable parameters (based on snippet instructions)
78
+ codec_audio_sr=16000, # do not change
79
+ codec_sr=50, # do not change
80
+ top_p=1, # do not change
81
+ min_p=1, # do not change
82
+ silence_tokens=None, # do not change it
83
+ multi_trial=None, # do not change it
84
+ sample_batch_size=1, # do not change
85
+ cut_off_sec=100, # do not adjust
86
+ ):
87
+ """
88
+ Inference script for VoiceStar TTS.
89
+ """
90
+ # 1. Set seed
91
+ seed_everything(seed)
92
+
93
+ # 2. Load model checkpoint
94
+ torch.serialization.add_safe_globals([Namespace])
95
+ device = "cuda" if torch.cuda.is_available() else "cpu"
96
+ ckpt_fn = os.path.join(model_root, model_name + ".pth")
97
+ if not os.path.exists(ckpt_fn):
98
+ # use wget to download
99
+ print(f"[Info] Downloading {model_name} checkpoint...")
100
+ os.system(f"wget 'https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true' -O {ckpt_fn}")
101
+ bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
102
+ args = bundle["args"]
103
+ phn2num = bundle["phn2num"]
104
+
105
+ model = voice_star.VoiceStar(args)
106
+ model.load_state_dict(bundle["model"])
107
+ model.to(device)
108
+ model.eval()
109
+
110
+ # 3. If reference_text not provided, transcribe reference speech with Whisper
111
+ if reference_text is None:
112
+ print("[Info] No reference_text provided. Transcribing reference_speech with Whisper (large-v3-turbo).")
113
+ wh_model = whisper.load_model("large-v3-turbo")
114
+ result = wh_model.transcribe(reference_speech)
115
+ prefix_transcript = result["text"]
116
+ print(f"[Info] Whisper transcribed text: {prefix_transcript}")
117
+ else:
118
+ prefix_transcript = reference_text
119
+
120
+ # 4. If target_duration not provided, estimate from reference speech + target_text
121
+ if target_duration is None:
122
+ target_generation_length = estimate_duration(reference_speech, target_text)
123
+ print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f}s. Provide --target_duration if needed.")
124
+ else:
125
+ target_generation_length = float(target_duration)
126
+
127
+ # 5. Prepare signature from snippet
128
+ if args.n_codebooks == 4:
129
+ signature = "./pretrained/encodec_6f79c6a8.th"
130
+ elif args.n_codebooks == 8:
131
+ signature = "./pretrained/encodec_8cb1024_giga.th"
132
+ else:
133
+ signature = "./pretrained/encodec_6f79c6a8.th"
134
+
135
+ if silence_tokens is None:
136
+ silence_tokens = []
137
+
138
+ if multi_trial is None:
139
+ multi_trial = []
140
+
141
+ delay_pattern_increment = args.n_codebooks + 1 # delayed codebook pattern plus the eos token
142
+
143
+ info = torchaudio.info(reference_speech)
144
+ prompt_end_frame = int(cut_off_sec * info.sample_rate)
145
+
146
+ # 6. Tokenizers
147
+ audio_tokenizer = AudioTokenizer(signature=signature)
148
+ text_tokenizer = TextTokenizer(backend="espeak")
149
+
150
+ # 7. decode_config
151
+ decode_config = {
152
+ "top_k": top_k,
153
+ "top_p": top_p,
154
+ "min_p": min_p,
155
+ "temperature": temperature,
156
+ "stop_repetition": stop_repetition,
157
+ "kvcache": kvcache,
158
+ "codec_audio_sr": codec_audio_sr,
159
+ "codec_sr": codec_sr,
160
+ "silence_tokens": silence_tokens,
161
+ "sample_batch_size": sample_batch_size,
162
+ }
163
+
164
+ # 8. Run inference
165
+ print("[Info] Running TTS inference...")
166
+ concated_audio, gen_audio = inference_one_sample(
167
+ model, args, phn2num, text_tokenizer, audio_tokenizer,
168
+ reference_speech, target_text,
169
+ device, decode_config,
170
+ prompt_end_frame=prompt_end_frame,
171
+ target_generation_length=target_generation_length,
172
+ delay_pattern_increment=delay_pattern_increment,
173
+ prefix_transcript=prefix_transcript,
174
+ multi_trial=multi_trial,
175
+ repeat_prompt=repeat_prompt,
176
+ )
177
+
178
+ # The model returns a list of waveforms, pick the first
179
+ concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
180
+
181
+ # 9. Save generated audio
182
+ os.makedirs(output_dir, exist_ok=True)
183
+ out_filename = "generated.wav"
184
+ out_path = os.path.join(output_dir, out_filename)
185
+ torchaudio.save(out_path, gen_audio, codec_audio_sr)
186
+
187
+ print(f"[Success] Generated audio saved to {out_path}")
188
+ return out_path # Return the path for Gradio to load
189
+
190
+
191
+ ############################
192
+ # Transcription function
193
+ ############################
194
+
195
+ def transcribe_audio(reference_speech):
196
+ """
197
+ Transcribe uploaded reference audio with Whisper, return text.
198
+ If no file, return empty string.
199
+ """
200
+ if reference_speech is None:
201
+ return ""
202
+ audio_path = reference_speech # Because type="filepath"
203
+
204
+ if not os.path.exists(audio_path):
205
+ return "File not found."
206
+
207
+ print("[Info] Transcribing with Whisper...")
208
+ model = whisper.load_model("medium") # or "large-v2" etc.
209
+ result = model.transcribe(audio_path)
210
+ return result["text"]
211
+
212
+ ############################
213
+ # Gradio UI
214
+ ############################
215
+
216
+ def main():
217
+ with gr.Blocks() as demo:
218
+ gr.Markdown("## VoiceStar TTS with Editable Reference Text")
219
+
220
+ with gr.Row():
221
+ reference_speech_input = gr.Audio(
222
+ label="Reference Speech",
223
+ type="filepath",
224
+ elem_id="ref_speech"
225
+ )
226
+ transcribe_button = gr.Button("Transcribe")
227
+
228
+ # The transcribed text appears here and can be edited
229
+ reference_text_box = gr.Textbox(
230
+ label="Reference Text (Editable)",
231
+ placeholder="Click 'Transcribe' to auto-fill from reference speech...",
232
+ lines=2
233
+ )
234
+
235
+ target_text_box = gr.Textbox(
236
+ label="Target Text",
237
+ value="VoiceStar is a very interesting model, it's duration controllable and can extrapolate to unseen duration.",
238
+ lines=3
239
+ )
240
+
241
+ model_name_box = gr.Textbox(
242
+ label="Model Name",
243
+ value="VoiceStar_840M_40s"
244
+ )
245
+
246
+ model_root_box = gr.Textbox(
247
+ label="Model Root Directory",
248
+ value="./pretrained"
249
+ )
250
+
251
+ reference_duration_box = gr.Textbox(
252
+ label="Target Duration (Optional)",
253
+ placeholder="Leave empty for auto-estimate."
254
+ )
255
+
256
+ top_k_box = gr.Number(label="top_k", value=10)
257
+ temperature_box = gr.Number(label="temperature", value=1.0)
258
+ kvcache_box = gr.Number(label="kvcache (1 or 0)", value=1)
259
+ repeat_prompt_box = gr.Number(label="repeat_prompt", value=1)
260
+ stop_repetition_box = gr.Number(label="stop_repetition", value=3)
261
+ seed_box = gr.Number(label="Random Seed", value=1)
262
+ output_dir_box = gr.Textbox(label="Output Directory", value="./generated_tts")
263
+
264
+ generate_button = gr.Button("Generate TTS")
265
+ output_audio = gr.Audio(label="Generated Audio", type="filepath")
266
+
267
+ # 1) When user clicks "Transcribe", we call `transcribe_audio`
268
+ transcribe_button.click(
269
+ fn=transcribe_audio,
270
+ inputs=[reference_speech_input],
271
+ outputs=[reference_text_box],
272
+ )
273
+
274
+ # 2) The actual TTS generation function.
275
+ def gradio_inference(
276
+ reference_speech,
277
+ reference_text,
278
+ target_text,
279
+ model_name,
280
+ model_root,
281
+ target_duration,
282
+ top_k,
283
+ temperature,
284
+ kvcache,
285
+ repeat_prompt,
286
+ stop_repetition,
287
+ seed,
288
+ output_dir
289
+ ):
290
+ # Convert any empty strings to None for optional fields
291
+ dur = float(target_duration) if target_duration else None
292
+
293
+ out_path = run_inference(
294
+ reference_speech=reference_speech,
295
+ reference_text=reference_text if reference_text else None,
296
+ target_text=target_text,
297
+ model_name=model_name,
298
+ model_root=model_root,
299
+ target_duration=dur,
300
+ top_k=int(top_k),
301
+ temperature=float(temperature),
302
+ kvcache=int(kvcache),
303
+ repeat_prompt=int(repeat_prompt),
304
+ stop_repetition=int(stop_repetition),
305
+ seed=int(seed),
306
+ output_dir=output_dir
307
+ )
308
+ return out_path
309
+
310
+ # 3) Link the "Generate TTS" button
311
+ generate_button.click(
312
+ fn=gradio_inference,
313
+ inputs=[
314
+ reference_speech_input,
315
+ reference_text_box,
316
+ target_text_box,
317
+ model_name_box,
318
+ model_root_box,
319
+ reference_duration_box,
320
+ top_k_box,
321
+ temperature_box,
322
+ kvcache_box,
323
+ repeat_prompt_box,
324
+ stop_repetition_box,
325
+ seed_box,
326
+ output_dir_box
327
+ ],
328
+ outputs=[output_audio],
329
+ )
330
+
331
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
332
+
333
+ if __name__ == "__main__":
334
+ main()
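
For reference, a minimal sketch of driving run_inference above directly from Python, without launching the Gradio UI. The import path (inference_gradio) and the exact call are assumptions based on the file names in this upload, not code shipped here; it also assumes the checkpoints resolve or download under ./pretrained as in the defaults.

    # Sketch: call run_inference programmatically (hypothetical usage)
    from inference_gradio import run_inference

    out_wav = run_inference(
        reference_speech="./demo/5895_34622_000026_000002.wav",  # demo clip in this repo
        target_text="VoiceStar reads this sentence at the duration you request.",
        target_duration=6.0,   # seconds; leave as None to auto-estimate
        repeat_prompt=2,       # repeating the prompt tends to improve speaker similarity
        seed=1,
        output_dir="./generated_tts",
    )
    print("saved to", out_wav)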
inference_tts_utils.py ADDED
@@ -0,0 +1,155 @@
1
+ import argparse, pickle
2
+ import logging
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from data.tokenizer import (
9
+ AudioTokenizer,
10
+ TextTokenizer,
11
+ tokenize_audio,
12
+ tokenize_text
13
+ )
14
+ import argparse, time, tqdm
15
+
16
+
17
+ # this script only works for the musicgen architecture
18
+ def get_args():
19
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
+ parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
21
+ parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
22
+ parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
23
+ parser.add_argument("--seed", type=int, default=1)
24
+ parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
25
+ parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
26
+ parser.add_argument("--top_k", type=int, default=0, help="sampling param")
27
+ parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
28
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
29
+ parser.add_argument("--output_dir", type=str, default=None)
30
+ parser.add_argument("--device", type=str, default="cuda")
31
+ parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
32
+ parser.add_argument("--crop_concat", type=int, default=0)
33
+ parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
34
+ parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
35
+ parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation")
36
+ parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
37
+ return parser.parse_args()
38
+
39
+
40
+ @torch.no_grad()
41
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, target_generation_length, delay_pattern_increment, prefix_transcript=None, quiet=False, repeat_prompt=0, multi_trial=[]):
42
+ # seq_len_thres = 500 # 10s, 26% of the data in seed tts
43
+ # encode audio
44
+ encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
45
+ # if sequence length is shorter than seq_len_thres, repeat the audio
46
+ # if encoded_frames.shape[2] < seq_len_thres:
47
+ # encoded_frames = torch.cat([encoded_frames, encoded_frames, encoded_frames], dim=2)
48
+ # doubled = True
49
+ single_encoded_frames = encoded_frames
50
+
51
+ if isinstance(repeat_prompt, int) and repeat_prompt > 0:
52
+ cur_repeat_prompt = repeat_prompt
53
+ while cur_repeat_prompt > 0:
54
+ encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2)
55
+ cur_repeat_prompt -= 1
56
+ elif isinstance(repeat_prompt, str) and repeat_prompt.lower() == "max":
57
+ repeat_prompt = 0
58
+ while encoded_frames.shape[2] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment + single_encoded_frames.shape[2] < model_args.audio_max_length * decode_config['codec_sr']:
59
+ encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2)
60
+ repeat_prompt += 1
61
+ if getattr(model_args, "y_sep_token", None) != None:
62
+ encoded_frames = torch.cat([encoded_frames, torch.LongTensor([model_args.y_sep_token]*model_args.n_codebooks).unsqueeze(0).unsqueeze(2).to(encoded_frames.device)], dim=2)
63
+ # print(encoded_frames.shape)
64
+ original_audio = encoded_frames.transpose(2,1) # [1,T,K]
65
+ assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
66
+
67
+ # phonemize
68
+ if isinstance(target_text, list):
69
+ text_tokens = [phn2num[phn] for phn in target_text if phn in phn2num]
70
+ else:
71
+ text_tokens = [phn2num[phn] for phn in
72
+ tokenize_text(
73
+ text_tokenizer, text=target_text.strip()
74
+ ) if phn in phn2num
75
+ ]
76
+ if getattr(model_args, "x_sep_token", None) != None:
77
+ assert prefix_transcript != None, "prefix_transcript must be provided if x_sep_token is not None"
78
+ if prefix_transcript is not None:
79
+ if isinstance(prefix_transcript, list):
80
+ prefix_tokens = [phn2num[phn] for phn in prefix_transcript if phn in phn2num]
81
+ else:
82
+ prefix_tokens = [phn2num[phn] for phn in
83
+ tokenize_text(
84
+ text_tokenizer, text=prefix_transcript.strip()
85
+ ) if phn in phn2num
86
+ ]
87
+ # if doubled:
88
+ # prefix_tokens = prefix_tokens + prefix_tokens + prefix_tokens
89
+ single_prefix_tokens = prefix_tokens
90
+ while repeat_prompt > 0:
91
+ prefix_tokens = prefix_tokens + single_prefix_tokens
92
+ repeat_prompt -= 1
93
+ if getattr(model_args, "x_sep_token", None) != None:
94
+ text_tokens = prefix_tokens + [getattr(model_args, "x_sep_token", None)] + text_tokens
95
+ else:
96
+ text_tokens = prefix_tokens + text_tokens
97
+ if getattr(model_args, "add_eos_to_text", 0) != 0:
98
+ text_tokens.append(model_args.add_eos_to_text)
99
+ if getattr(model_args, "add_bos_to_text", 0) != 0:
100
+ text_tokens = [model_args.add_bos_to_text] + text_tokens
101
+ text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
102
+ text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
103
+
104
+ if not quiet:
105
+ logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
106
+
107
+
108
+ if getattr(model_args, "parallel_pattern", 0) != 0:
109
+ tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + 2)]) # parallel pattern, therefore only add the empty_token (i.e. the sos token) and eos (i.e. 2 more tokens). Note that the delayed pattern between, both sos and eos is counted (sos is counted in the n_codebooks, eos is counted in the 1)
110
+ else:
111
+ tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment)]) # delay pattern increment has accounted for the added eos
112
+
113
+ # forward
114
+ assert decode_config['sample_batch_size'] <= 1
115
+ stime = time.time()
116
+ assert multi_trial == []
117
+ if not quiet:
118
+ logging.info(f"running inference with batch size 1")
119
+ concat_frames, gen_frames = model.inference_tts(
120
+ text_tokens.to(device),
121
+ text_tokens_lens.to(device),
122
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
123
+ tgt_y_lens = tgt_y_lens.to(device),
124
+ top_k=decode_config['top_k'],
125
+ top_p=decode_config['top_p'],
126
+ min_p=decode_config['min_p'],
127
+ temperature=decode_config['temperature'],
128
+ stop_repetition=decode_config['stop_repetition'],
129
+ kvcache=decode_config['kvcache'],
130
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
131
+ ) # output is [1,K,T]
132
+ if not quiet:
133
+ logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
134
+
135
+ logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
136
+
137
+ # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
138
+ # logging.info(f"{timestamp}: {codes.tolist()}")
139
+ # decode (both original and generated)
140
+ # concat_sample = audio_tokenizer.decode(
141
+ # [(concat_frames, None)] # [1,T,8] -> [1,8,T]
142
+ # )
143
+ if getattr(model_args, "y_sep_token", None) != None:
144
+ concat_frames = torch.cat([concat_frames[:, :, :original_audio.shape[1]-1], concat_frames[:, :, original_audio.shape[1]:]], dim=2)
145
+ concat_sample = audio_tokenizer.decode(
146
+ concat_frames # [1,8,T]
147
+ )
148
+ gen_sample = audio_tokenizer.decode(
149
+ gen_frames
150
+ )
151
+ # Empty cuda cache between runs
152
+ if torch.cuda.is_available():
153
+ torch.cuda.empty_cache()
154
+ # return
155
+ return concat_sample, gen_sample
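
To make the length bookkeeping in inference_one_sample above concrete, here is a small sketch of how the requested number of codec frames (tgt_y_lens) comes together under the delayed-pattern branch. The prompt length and codebook count are illustrative assumptions, not values taken from this upload.

    # Sketch of the tgt_y_lens computation (illustrative numbers only)
    codec_sr = 50                                # codec frames per second (default above)
    n_codebooks = 8                              # assumed codebook count
    delay_pattern_increment = n_codebooks + 1    # delayed pattern plus eos
    prompt_frames = 412                          # e.g. an 8.24 s reference prompt
    target_seconds = 6.0

    tgt_y_lens = int(prompt_frames + codec_sr * target_seconds + delay_pattern_increment)
    print(tgt_y_lens)                            # 412 + 300 + 9 = 721 frames requested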
main.py ADDED
@@ -0,0 +1,82 @@
1
+ from pathlib import Path
2
+ import torch, os
3
+
4
+ from tqdm import tqdm
5
+ import pickle
6
+ import argparse
7
+ import logging, datetime
8
+ import torch.distributed as dist
9
+ from config import MyParser
10
+ from steps import trainer
11
+ from copy_codebase import copy_codebase
12
+
13
+ def world_info_from_env():
14
+ local_rank = int(os.environ["LOCAL_RANK"])
15
+ global_rank = int(os.environ["RANK"])
16
+ world_size = int(os.environ["WORLD_SIZE"])
17
+ return local_rank, global_rank, world_size
18
+
19
+ if __name__ == "__main__":
20
+ formatter = (
21
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
22
+ )
23
+ logging.basicConfig(format=formatter, level=logging.INFO)
24
+
25
+ torch.cuda.empty_cache()
26
+ args = MyParser().parse_args()
27
+ exp_dir = Path(args.exp_dir)
28
+ exp_dir.mkdir(exist_ok=True, parents=True)
29
+ logging.info(f"exp_dir: {str(exp_dir)}")
30
+
31
+ if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)):
32
+ if not os.path.exists("%s/bundle.pth" % args.exp_dir):
33
+ os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth")
34
+ resume = args.resume
35
+ assert(bool(args.exp_dir))
36
+ with open("%s/args.pkl" % args.exp_dir, "rb") as f:
37
+ old_args = pickle.load(f)
38
+ new_args = vars(args)
39
+ old_args = vars(old_args)
40
+ for key in new_args:
41
+ if key not in old_args or old_args[key] != new_args[key]:
42
+ old_args[key] = new_args[key]
43
+ args = argparse.Namespace(**old_args)
44
+ args.resume = resume
45
+ else:
46
+ args.resume = False
47
+ with open("%s/args.pkl" % args.exp_dir, "wb") as f:
48
+ pickle.dump(args, f)
49
+
50
+ # make timeout longer (for generation)
51
+ timeout = datetime.timedelta(seconds=7200) # 2 hours
52
+
53
+ if args.multinodes:
54
+ _local_rank, _, _ = world_info_from_env()
55
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
56
+ else:
57
+ dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)
58
+
59
+ if args.local_wandb:
60
+ os.environ["WANDB_MODE"] = "offline"
61
+
62
+ rank = dist.get_rank()
63
+ if rank == 0:
64
+ logging.info(args)
65
+ logging.info(f"exp_dir: {str(exp_dir)}")
66
+ world_size = dist.get_world_size()
67
+
68
+ local_rank = int(_local_rank) if args.multinodes else rank
69
+ num_devices= torch.cuda.device_count()
70
+ logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
71
+ for device_idx in range(num_devices):
72
+ device_name = torch.cuda.get_device_name(device_idx)
73
+ logging.info(f"Device {device_idx}: {device_name}")
74
+
75
+ torch.cuda.set_device(local_rank)
76
+ if rank == 0:
77
+ user_dir = os.path.expanduser("~")
78
+ codebase_name = "VoiceStar"
79
+ now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
80
+ copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"))
81
+ my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
82
+ my_trainer.train()
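
The resume branch above merges the saved args.pkl with the freshly parsed arguments; a small self-contained sketch of that merge rule (new values win, keys only present in the old run survive), using made-up fields:

    # Sketch of the argument-merge behavior on resume (hypothetical fields)
    import argparse

    old = vars(argparse.Namespace(exp_dir="runs/a", lr=1e-4, batch_size=32))
    new = vars(argparse.Namespace(exp_dir="runs/a", lr=5e-5))

    for key in new:                      # same loop as in main.py above
        if key not in old or old[key] != new[key]:
            old[key] = new[key]

    merged = argparse.Namespace(**old)
    print(merged.lr, merged.batch_size)  # 5e-05 32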
models/modules/__init__.py ADDED
File without changes
models/modules/activation.py ADDED
@@ -0,0 +1,781 @@
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+ import logging
+ import warnings
12
+ from typing import Callable, List, Optional, Tuple, Union
13
+ from typing import TYPE_CHECKING
14
+ if TYPE_CHECKING:
15
+ from torch.types import _dtype as DType
16
+ else:
17
+ # The JIT doesn't understand Union, nor torch.dtype here
18
+ DType = int
19
+
20
+ def _canonical_mask(
21
+ mask: Optional[Tensor],
22
+ mask_name: str,
23
+ other_type: Optional[DType],
24
+ other_name: str,
25
+ target_type: DType,
26
+ check_other: bool = True,
27
+ ) -> Optional[Tensor]:
28
+
29
+ if mask is not None:
30
+ _mask_dtype = mask.dtype
31
+ _mask_is_float = torch.is_floating_point(mask)
32
+ if _mask_dtype != torch.bool and not _mask_is_float:
33
+ raise AssertionError(
34
+ f"only bool and floating types of {mask_name} are supported")
35
+ if check_other and other_type is not None:
36
+ if _mask_dtype != other_type:
37
+ warnings.warn(
38
+ f"Support for mismatched {mask_name} and {other_name} "
39
+ "is deprecated. Use same type for both instead."
40
+ )
41
+ if not _mask_is_float:
42
+ mask = (
43
+ torch.zeros_like(mask, dtype=target_type)
44
+ .masked_fill_(mask, float("-inf"))
45
+ )
46
+ return mask
47
+
48
+ def _in_projection_packed(
49
+ q: Tensor,
50
+ k: Tensor,
51
+ v: Tensor,
52
+ w: Tensor,
53
+ b: Optional[Tensor] = None,
54
+ ) -> List[Tensor]:
55
+ r"""
56
+ Performs the in-projection step of the attention operation, using packed weights.
57
+ Output is a triple containing projection tensors for query, key and value.
58
+
59
+ Args:
60
+ q, k, v: query, key and value tensors to be projected. For self-attention,
61
+ these are typically the same tensor; for encoder-decoder attention,
62
+ k and v are typically the same tensor. (We take advantage of these
63
+ identities for performance if they are present.) Regardless, q, k and v
64
+ must share a common embedding dimension; otherwise their shapes may vary.
65
+ w: projection weights for q, k and v, packed into a single tensor. Weights
66
+ are packed along dimension 0, in q, k, v order.
67
+ b: optional projection biases for q, k and v, packed into a single tensor
68
+ in q, k, v order.
69
+
70
+ Shape:
71
+ Inputs:
72
+ - q: :math:`(..., E)` where E is the embedding dimension
73
+ - k: :math:`(..., E)` where E is the embedding dimension
74
+ - v: :math:`(..., E)` where E is the embedding dimension
75
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
76
+ - b: :math:`E * 3` where E is the embedding dimension
77
+
78
+ Output:
79
+ - in output list :math:`[q', k', v']`, each output tensor will have the
80
+ same shape as the corresponding input tensor.
81
+ """
82
+ E = q.size(-1)
83
+ if k is v:
84
+ if q is k:
85
+ # self-attention
86
+ proj = F.linear(q, w, b)
87
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
88
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
89
+ return proj[0], proj[1], proj[2]
90
+ else:
91
+ # encoder-decoder attention
92
+ w_q, w_kv = w.split([E, E * 2])
93
+ if b is None:
94
+ b_q = b_kv = None
95
+ else:
96
+ b_q, b_kv = b.split([E, E * 2])
97
+ q_proj = F.linear(q, w_q, b_q)
98
+ kv_proj = F.linear(k, w_kv, b_kv)
99
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
100
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
101
+ return (q_proj, kv_proj[0], kv_proj[1])
102
+ else:
103
+ w_q, w_k, w_v = w.chunk(3)
104
+ if b is None:
105
+ b_q = b_k = b_v = None
106
+ else:
107
+ b_q, b_k, b_v = b.chunk(3)
108
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
109
+
110
+ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
111
+ if input is None:
112
+ return None
113
+ elif isinstance(input, torch.Tensor):
114
+ return input.dtype
115
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
116
+
117
+ def rotate_half(x):
118
+ x1 = x[..., :x.shape[-1] // 2]
119
+ x2 = x[..., x.shape[-1] // 2:]
120
+ return torch.cat([-x2, x1], dim=-1)
121
+
122
+ def apply_rotary_pos_emb(q, k, q_sinu=None, k_sinu=None, sinu=None, unsqueeze_dim=1, args=None, q_offset=0):
123
+ if sinu is not None:
124
+ q_emb = q * sinu['cos'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(q) * sinu['sin'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim)
125
+ k_emb = k * sinu['cos'][:, :k.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(k) * sinu['sin'][:, :k.shape[2]].unsqueeze(unsqueeze_dim)
126
+ if q_sinu is not None:
127
+ assert sinu is None, "sinu must be None"
128
+ q_emb = q * q_sinu['cos'][:, :, q_offset:q_offset+q.shape[2]] + rotate_half(q) * q_sinu['sin'][:, :, q_offset:q_offset+q.shape[2]]
129
+ k_emb = k * k_sinu['cos'][:, :, :k.shape[2]] + rotate_half(k) * k_sinu['sin'][:, :, :k.shape[2]]
130
+ # else:
131
+ # assert freqs is not None, "freqs must be provided"
132
+ # assert key_lens is not None, "key_lens must be provided"
133
+ # assert query_lens is not None, "query_lens must be provided"
134
+ # # key_multiple
135
+ # assert key_lens.ndim==1, key_lens
136
+ # assert query_lens.ndim==1, query_lens
137
+ # q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
138
+ # k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
139
+ # query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
140
+ # key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1)
141
+ # # freqs.shape [1, x_len_max, d]
142
+ # # torch.set_printoptions(edgeitems=200)
143
+ # # logging.info(f"{freqs[:, :q.shape[2]]=}")
144
+ # # logging.info(f"{query_ids_multiple=}")
145
+ # # logging.info(f"{key_ids_multiple=}")
146
+ # # print(f"q.shape: {q.shape}")
147
+ # # print(f"q_offset: {q_offset}")
148
+ # # print(f"k.shape: {k.shape}")
149
+ # q_emb = freqs[:, q_offset:q_offset+q.shape[-2]] * query_ids_multiple # [B, q_len_max, d]
150
+ # k_emb = freqs[:, :k.shape[2]] * key_ids_multiple # [B, k_len_max, d]
151
+ # # logging.info(f"{q_emb[0, :, :5]=}")
152
+ # # logging.info(f"{k_emb[0, :, :5]=}")
153
+ # multiple = k_lens_expanded if multiple_key_length else q_lens_expanded
154
+ # if progress_no_multiple:
155
+ # multiple = 1
156
+ # q_emb = q_emb / q_lens_expanded * multiple * progress_scale
157
+ # k_emb = k_emb / k_lens_expanded * multiple * progress_scale
158
+ # q_cos = q_emb.cos().unsqueeze(unsqueeze_dim) # [B, 1, q_len_max, d] # 1 is for nhead
159
+ # q_sin = q_emb.sin().unsqueeze(unsqueeze_dim)
160
+ # k_cos = k_emb.cos().unsqueeze(unsqueeze_dim)
161
+ # k_sin = k_emb.sin().unsqueeze(unsqueeze_dim)
162
+ # # # visualize rotary pos emb with dummy feature
163
+ # # q_tmp = torch.ones_like(q)
164
+ # # k_tmp = torch.ones_like(k)
165
+ # # q_tmp_emb = q_tmp * q_cos + rotate_half(q_tmp) * q_sin
166
+ # # k_tmp_emb = k_tmp * k_cos + rotate_half(k_tmp) * k_sin
167
+ # # sims = q_tmp_emb @ k_tmp_emb.transpose(-2, -1)
168
+ # # import matplotlib.pyplot as plt
169
+ # # for i, sim in enumerate(sims):
170
+ # # plt.imshow(sim[0][:query_lens[i], :key_lens[i]].detach().cpu().numpy())
171
+ # # plt.savefig(f"sim{i}_head0.png")
172
+ # # plt.imshow(sim[5][:query_lens[i], :key_lens[i]].detach().cpu().numpy())
173
+ # # plt.savefig(f"sim{i}_head5.png")
174
+ # q_emb = q * q_cos + rotate_half(q) * q_sin
175
+ # k_emb = k * k_cos + rotate_half(k) * k_sin
176
+ # # # visualize the real attention weights
177
+ # # sims = q_emb @ k_emb.transpose(-2, -1)
178
+ # # from datetime import datetime
179
+ # # from matplotlib import pyplot as plt
180
+ # # now = datetime.now()
181
+ # # for i, sim in enumerate(sims):
182
+ # # for ihead, si in enumerate(sim):
183
+ # # if query_lens[i] == key_lens[i]:
184
+ # # continue
185
+ # # plt.imshow(si[:query_lens[i], :key_lens[i]].detach().cpu().numpy())
186
+ # # plt.savefig(f"sample{i}_head{ihead}_{now.strftime('%Y-%m-%d_%H-%M-%S')}.png")
187
+ return q_emb, k_emb
188
+
189
+
190
+ class MultiheadAttention(Module):
191
+ r"""Allows the model to jointly attend to information
192
+ from different representation subspaces as described in the paper:
193
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
194
+
195
+ Multi-Head Attention is defined as:
196
+
197
+ .. math::
198
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
199
+
200
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
201
+
202
+ ``forward()`` will use a special optimized implementation if all of the following
203
+ conditions are met:
204
+
205
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
206
+ restriction will be loosened in the future.)
207
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
208
+ - training is disabled (using ``.eval()``)
209
+ - dropout is 0
210
+ - ``add_bias_kv`` is ``False``
211
+ - ``add_zero_attn`` is ``False``
212
+ - ``batch_first`` is ``True`` and the input is batched
213
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
214
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
215
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
216
+ nor ``attn_mask`` is passed
217
+
218
+ If the optimized implementation is in use, a
219
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
220
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
221
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
222
+ will be returned, and an additional speedup proportional to the fraction of the input
223
+ that is padding can be expected.
224
+
225
+ Args:
226
+ embed_dim: Total dimension of the model.
227
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
228
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
229
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
230
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
231
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
232
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
233
+ Default: ``False``.
234
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
235
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
236
+ batch_first: If ``True``, then the input and output tensors are provided
237
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
238
+
239
+ Examples::
240
+
241
+ >>> # xdoctest: +SKIP
242
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
243
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
244
+
245
+ """
246
+ __constants__ = ["batch_first"]
247
+ bias_k: Optional[torch.Tensor]
248
+ bias_v: Optional[torch.Tensor]
249
+
250
+ def __init__(
251
+ self,
252
+ embed_dim,
253
+ num_heads,
254
+ dropout=0.0,
255
+ bias=True,
256
+ add_bias_kv=False,
257
+ add_zero_attn=False,
258
+ kdim=None,
259
+ vdim=None,
260
+ batch_first=False,
261
+ linear1_cls=Linear,
262
+ linear2_cls=Linear,
263
+ device=None,
264
+ dtype=None,
265
+ ) -> None:
266
+ factory_kwargs = {"device": device, "dtype": dtype}
267
+ super(MultiheadAttention, self).__init__()
268
+ self.embed_dim = embed_dim
269
+ self.kdim = kdim if kdim is not None else embed_dim
270
+ self.vdim = vdim if vdim is not None else embed_dim
271
+ self._qkv_same_embed_dim = (
272
+ self.kdim == embed_dim and self.vdim == embed_dim
273
+ )
274
+
275
+ self.num_heads = num_heads
276
+ self.dropout = dropout
277
+ self.batch_first = batch_first
278
+ self.head_dim = embed_dim // num_heads
279
+ assert (
280
+ self.head_dim * num_heads == self.embed_dim
281
+ ), "embed_dim must be divisible by num_heads"
282
+
283
+ if add_bias_kv:
284
+ self.bias_k = Parameter(
285
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
286
+ )
287
+ self.bias_v = Parameter(
288
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
289
+ )
290
+ else:
291
+ self.bias_k = self.bias_v = None
292
+
293
+ if linear1_cls == Linear:
294
+ if not self._qkv_same_embed_dim:
295
+ self.q_proj_weight = Parameter(
296
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
297
+ )
298
+ self.k_proj_weight = Parameter(
299
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
300
+ )
301
+ self.v_proj_weight = Parameter(
302
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
303
+ )
304
+ self.register_parameter("in_proj_weight", None)
305
+ else:
306
+ # go down this route with music_gen
307
+ self.in_proj_weight = Parameter(
308
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
309
+ )
310
+ self.register_parameter("q_proj_weight", None)
311
+ self.register_parameter("k_proj_weight", None)
312
+ self.register_parameter("v_proj_weight", None)
313
+
314
+ if bias: # True by default
315
+ self.in_proj_bias = Parameter(
316
+ torch.empty(3 * embed_dim, **factory_kwargs)
317
+ )
318
+ else:
319
+ self.register_parameter("in_proj_bias", None)
320
+ self.out_proj = NonDynamicallyQuantizableLinear(
321
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
322
+ )
323
+
324
+ self._reset_parameters()
325
+ else:
326
+ if not self._qkv_same_embed_dim:
327
+ raise NotImplementedError
328
+ else:
329
+ self.in_proj_linear = linear1_cls(
330
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
331
+ )
332
+ self.in_proj_weight = self.in_proj_linear.weight
333
+
334
+ self.register_parameter("q_proj_weight", None)
335
+ self.register_parameter("k_proj_weight", None)
336
+ self.register_parameter("v_proj_weight", None)
337
+
338
+ if bias:
339
+ self.in_proj_bias = self.in_proj_linear.bias
340
+ else:
341
+ self.register_parameter("in_proj_bias", None)
342
+
343
+ self.out_proj = linear2_cls(
344
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
345
+ )
346
+
347
+ if self.bias_k is not None:
348
+ xavier_normal_(self.bias_k)
349
+ if self.bias_v is not None:
350
+ xavier_normal_(self.bias_v)
351
+
352
+ self.add_zero_attn = add_zero_attn
353
+
354
+ def _reset_parameters(self):
355
+ if self._qkv_same_embed_dim:
356
+ xavier_uniform_(self.in_proj_weight)
357
+ else:
358
+ xavier_uniform_(self.q_proj_weight)
359
+ xavier_uniform_(self.k_proj_weight)
360
+ xavier_uniform_(self.v_proj_weight)
361
+
362
+ if self.in_proj_bias is not None:
363
+ constant_(self.in_proj_bias, 0.0)
364
+ constant_(self.out_proj.bias, 0.0)
365
+
366
+ if self.bias_k is not None:
367
+ xavier_normal_(self.bias_k)
368
+ if self.bias_v is not None:
369
+ xavier_normal_(self.bias_v)
370
+
371
+ def __setstate__(self, state):
372
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
373
+ if "_qkv_same_embed_dim" not in state:
374
+ state["_qkv_same_embed_dim"] = True
375
+
376
+ super(MultiheadAttention, self).__setstate__(state)
377
+
378
+ def forward(
379
+ self,
380
+ query: Tensor,
381
+ key: Tensor,
382
+ value: Tensor,
383
+ key_padding_mask: Optional[Tensor] = None,
384
+ need_weights: bool = True,
385
+ attn_mask: Optional[Tensor] = None,
386
+ average_attn_weights: bool = True,
387
+ past: Optional[Tensor] = None,
388
+ q_sinu = None,
389
+ k_sinu = None,
390
+ sinu = None,
391
+ args = None,
392
+ q_offset = 0,
393
+ ) -> Tuple[Tensor, Optional[Tensor]]:
394
+ r"""
395
+ Args:
396
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
397
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
398
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
399
+ Queries are compared against key-value pairs to produce the output.
400
+ See "Attention Is All You Need" for more details.
401
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
402
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
403
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
404
+ See "Attention Is All You Need" for more details.
405
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
406
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
407
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
408
+ See "Attention Is All You Need" for more details.
409
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
410
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
411
+ Binary and byte masks are supported.
412
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
413
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
414
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
415
+ Default: ``True``.
416
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
417
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
418
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
419
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
420
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
421
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
422
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
423
+ the attention weight.
424
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
425
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
426
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
427
+ sinu: precomputed cos/sin table for direct rope positional encoding
428
+ q_sinu, k_sinu: separate query/key cos/sin tables for progress-style rope positional encoding
429
+ q_offset: position offset applied to the query under rope; with kvcache on, the query is usually of length 1, so the offset recovers its absolute position
430
+
431
+ Outputs:
432
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
433
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
434
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
435
+ embedding dimension ``embed_dim``.
436
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
437
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
438
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
439
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
440
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
441
+
442
+ .. note::
443
+ `batch_first` argument is ignored for unbatched inputs.
444
+ """
445
+ is_batched = query.dim() == 3
446
+ if key_padding_mask is not None:
447
+ _kpm_dtype = key_padding_mask.dtype
448
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
449
+ key_padding_mask
450
+ ):
451
+ raise AssertionError(
452
+ "only bool and floating types of key_padding_mask are supported"
453
+ )
454
+ why_not_fast_path = ""
455
+ if not is_batched:
456
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
457
+ elif query is not key or key is not value:
458
+ # When lifting this restriction, don't forget to either
459
+ # enforce that the dtypes all match or test cases where
460
+ # they don't!
461
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
462
+ elif (
463
+ self.in_proj_bias is not None
464
+ and query.dtype != self.in_proj_bias.dtype
465
+ ):
466
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
467
+ elif (
468
+ self.in_proj_weight is not None
469
+ and query.dtype != self.in_proj_weight.dtype
470
+ ):
471
+ # this case will fail anyway, but at least they'll get a useful error message.
472
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
473
+ elif self.training:
474
+ why_not_fast_path = "training is enabled"
475
+ elif not self.batch_first:
476
+ why_not_fast_path = "batch_first was not True"
477
+ elif self.bias_k is not None:
478
+ why_not_fast_path = "self.bias_k was not None"
479
+ elif self.bias_v is not None:
480
+ why_not_fast_path = "self.bias_v was not None"
481
+ elif self.dropout:
482
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
483
+ elif self.add_zero_attn:
484
+ why_not_fast_path = "add_zero_attn was enabled"
485
+ elif not self._qkv_same_embed_dim:
486
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
487
+ elif attn_mask is not None:
488
+ why_not_fast_path = "attn_mask was not None"
489
+ elif query.is_nested and key_padding_mask is not None:
490
+ why_not_fast_path = (
491
+ "key_padding_mask is not supported with NestedTensor input"
492
+ )
493
+ elif self.num_heads % 2 == 1:
494
+ why_not_fast_path = "num_heads is odd"
495
+ elif torch.is_autocast_enabled():
496
+ why_not_fast_path = "autocast is enabled"
497
+
498
+ if not why_not_fast_path:
499
+ tensor_args = (
500
+ query,
501
+ key,
502
+ value,
503
+ self.in_proj_weight,
504
+ self.in_proj_bias,
505
+ self.out_proj.weight,
506
+ self.out_proj.bias,
507
+ )
508
+ # We have to use list comprehensions below because TorchScript does not support
509
+ # generator expressions.
510
+ if torch.overrides.has_torch_function(tensor_args):
511
+ why_not_fast_path = "some Tensor argument has_torch_function"
512
+ elif not all(
513
+ [
514
+ (x is None or x.is_cuda or "cpu" in str(x.device))
515
+ for x in tensor_args
516
+ ]
517
+ ):
518
+ why_not_fast_path = (
519
+ "some Tensor argument is neither CUDA nor CPU"
520
+ )
521
+ elif torch.is_grad_enabled() and any(
522
+ [x is not None and x.requires_grad for x in tensor_args]
523
+ ):
524
+ why_not_fast_path = (
525
+ "grad is enabled and at least one of query or the "
526
+ "input/output projection weights or biases requires_grad"
527
+ )
528
+ if not why_not_fast_path:
529
+ return torch._native_multi_head_attention(
530
+ query,
531
+ key,
532
+ value,
533
+ self.embed_dim,
534
+ self.num_heads,
535
+ self.in_proj_weight,
536
+ self.in_proj_bias,
537
+ self.out_proj.weight,
538
+ self.out_proj.bias,
539
+ key_padding_mask
540
+ if key_padding_mask is not None
541
+ else attn_mask,
542
+ need_weights,
543
+ average_attn_weights,
544
+ 1
545
+ if key_padding_mask is not None
546
+ else 0
547
+ if attn_mask is not None
548
+ else None,
549
+ )
550
+
551
+ any_nested = query.is_nested or key.is_nested or value.is_nested
552
+ assert not any_nested, (
553
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
554
+ + f"The fast path was not hit because {why_not_fast_path}"
555
+ )
556
+
557
+ if self.batch_first and is_batched:
558
+ # make sure that the transpose op does not affect the "is" property
559
+ if key is value:
560
+ if query is key:
561
+ query = key = value = query.transpose(1, 0)
562
+ else:
563
+ query, key = [x.transpose(1, 0) for x in (query, key)]
564
+ value = key
565
+ else:
566
+ query, key, value = [
567
+ x.transpose(1, 0) for x in (query, key, value)
568
+ ]
569
+
570
+ if not self._qkv_same_embed_dim:
571
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
572
+ query,
573
+ key,
574
+ value,
575
+ self.embed_dim,
576
+ self.num_heads,
577
+ self.in_proj_weight,
578
+ self.in_proj_bias,
579
+ self.bias_k,
580
+ self.bias_v,
581
+ self.add_zero_attn,
582
+ self.dropout,
583
+ self.out_proj.weight,
584
+ self.out_proj.bias,
585
+ training=self.training,
586
+ key_padding_mask=key_padding_mask,
587
+ need_weights=need_weights,
588
+ attn_mask=attn_mask,
589
+ use_separate_proj_weight=True,
590
+ q_proj_weight=self.q_proj_weight,
591
+ k_proj_weight=self.k_proj_weight,
592
+ v_proj_weight=self.v_proj_weight,
593
+ average_attn_weights=average_attn_weights,
594
+ )
595
+ else:
596
+ # music_gen should go down this route
597
+ # logging.info("using in_proj_weight")
598
+ # attn_output, attn_output_weights = F.multi_head_attention_forward(
599
+ # query,
600
+ # key,
601
+ # value,
602
+ # self.embed_dim,
603
+ # self.num_heads,
604
+ # self.in_proj_weight,
605
+ # self.in_proj_bias,
606
+ # self.bias_k,
607
+ # self.bias_v,
608
+ # self.add_zero_attn,
609
+ # self.dropout,
610
+ # self.out_proj.weight,
611
+ # self.out_proj.bias,
612
+ # training=self.training,
613
+ # key_padding_mask=key_padding_mask,
614
+ # need_weights=need_weights,
615
+ # attn_mask=attn_mask,
616
+ # average_attn_weights=average_attn_weights,
617
+ # )
618
+ # re-write the self.attention here, to get k, v cache
619
+ tgt_len, bsz, embed_dim = query.shape
620
+ src_len, _, _ = key.shape
621
+ num_heads = self.num_heads
622
+ key_padding_mask = _canonical_mask(
623
+ mask=key_padding_mask,
624
+ mask_name="key_padding_mask",
625
+ other_type=_none_or_dtype(attn_mask),
626
+ other_name="attn_mask",
627
+ target_type=query.dtype
628
+ )
629
+ attn_mask = _canonical_mask(
630
+ mask=attn_mask,
631
+ mask_name="attn_mask",
632
+ other_type=None,
633
+ other_name="",
634
+ target_type=query.dtype,
635
+ check_other=False,
636
+ )
637
+ head_dim = self.embed_dim // self.num_heads
638
+ assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
639
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
640
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
641
+ # k_present, v_present = k, v
642
+
643
+ #
644
+ # reshape q, k, v for multihead attention and make em batch first
645
+ #
646
+
647
+
648
+
649
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
650
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
651
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
652
+ src_len = k.size(1)
653
+ if past is not None and past.ndim > 2:
654
+ expected_src_len = src_len + past[0].shape[-2]
655
+ else:
656
+ expected_src_len = src_len
657
+
658
+
659
+ # ensure attn_mask's dim is 3
660
+ if attn_mask is not None:
661
+ if attn_mask.dim() == 2:
662
+ correct_2d_size = (tgt_len, expected_src_len)
663
+ if attn_mask.shape != correct_2d_size:
664
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
665
+ attn_mask = attn_mask.unsqueeze(0)
666
+ elif attn_mask.dim() == 3:
667
+ correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
668
+ if attn_mask.shape != correct_3d_size:
669
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
670
+ else:
671
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
672
+
673
+ if key_padding_mask is not None:
674
+ assert key_padding_mask.shape == (bsz, expected_src_len), \
675
+ f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
676
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
677
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
678
+ if attn_mask is None:
679
+ attn_mask = key_padding_mask
680
+ else:
681
+ attn_mask = attn_mask + key_padding_mask
682
+
683
+ if not self.training:
684
+ dropout_p = 0.0
685
+ else:
686
+ dropout_p = self.dropout
687
+
688
+ if need_weights:
689
+ raise NotImplementedError("need_weights not implemented for music_gen")
690
+ # B, Nt, E = q.shape
691
+ # q_scaled = q / math.sqrt(E)
692
+
693
+ # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
694
+
695
+ # if attn_mask is not None:
696
+ # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
697
+ # else:
698
+ # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
699
+ # attn_output_weights = softmax(attn_output_weights, dim=-1)
700
+ # if dropout_p > 0.0:
701
+ # attn_output_weights = dropout(attn_output_weights, p=dropout_p)
702
+
703
+ # attn_output = torch.bmm(attn_output_weights, v)
704
+
705
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
706
+ # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
707
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
708
+
709
+ # # optionally average attention weights over heads
710
+ # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
711
+ # if average_attn_weights:
712
+ # attn_output_weights = attn_output_weights.mean(dim=1)
713
+
714
+ # if not is_batched:
715
+ # # squeeze the output if input was unbatched
716
+ # attn_output = attn_output.squeeze(1)
717
+ # attn_output_weights = attn_output_weights.squeeze(0)
718
+ # return attn_output, attn_output_weights
719
+ else:
720
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
721
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
722
+ # in order to match the input for SDPA of (N, num_heads, L, S)
723
+ if attn_mask is not None:
724
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
725
+ attn_mask = attn_mask.unsqueeze(0)
726
+ else:
727
+ attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
728
+
729
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
730
+ k = k.view(bsz, num_heads, src_len, head_dim)
731
+ v = v.view(bsz, num_heads, src_len, head_dim)
732
+ # logging.info(f"shape of past: {past.shape}")
733
+ if past is not None:
734
+ present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
735
+ if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
736
+ pk, pv = past
737
+ k = torch.cat([pk, k], dim=-2)
738
+ v = torch.cat([pv, v], dim=-2)
739
+ else:
740
+ present = None
741
+ # when using kvcache, need to offset position of q when applying rotary pos emb
742
+ # here we assume that this kvcache is only used in self-attention, and therefore k and q always have the same seq_len
743
+ # rope positional encoding
744
+ if sinu is not None:
745
+ # direct rotary
746
+ # logging.info("perform rotary positional encoding")
747
+ q, k = apply_rotary_pos_emb(q, k, sinu=sinu, args = args, q_offset=q_offset)
748
+ if q_sinu is not None:
749
+ assert sinu is None, "sinu and q_sinu cannot be used together"
750
+ assert k_sinu is not None, "k_sinu must be provided"
751
+ q, k = apply_rotary_pos_emb(q, k, q_sinu=q_sinu, k_sinu=k_sinu, args = args, q_offset=q_offset)
752
+
753
+ # if self.training and it's cross attention, will get attention_weights
754
+ if args != None and self.training and getattr(args, "attention_alignment_loss", 0) and not (query is key):
755
+ attention_weights = q @ k.transpose(-1, -2)
756
+ else:
757
+ attention_weights = None
758
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
759
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
760
+
761
+ attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
762
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
763
+ if not is_batched:
764
+ # squeeze the output if input was unbatched
765
+ attn_output = attn_output.squeeze(1)
766
+ # if self.training:
767
+ # return attn_output, None
768
+ # else:
769
+ # return (attn_output, present), None
770
+
771
+ # hard-coded: returning attention weights is not supported yet
772
+ attn_output_weights=None
773
+ if self.batch_first and is_batched:
774
+ if attention_weights != None:
775
+ return {"attn_output": attn_output.transpose(1, 0), "attention_weights": attention_weights}, present
776
+ return attn_output.transpose(1, 0), present
777
+ else:
778
+ if attention_weights != None:
779
+ return {"attn_output": attn_output, "attention_weights": attention_weights}, present
780
+ return attn_output, present
781
+
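
A sketch of the cos/sin table that apply_rotary_pos_emb above indexes when called with sinu=... . The shapes are inferred from the slicing in that function; the base of 10000 mirrors standard RoPE and is an assumption, as is importing from the repo root via models.modules.activation.

    # Sketch: build a sinu dict and rotate query/key tensors (assumed shapes)
    import torch
    from models.modules.activation import apply_rotary_pos_emb

    def build_sinu(max_len, head_dim, base=10000.0):
        inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        t = torch.arange(max_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)                      # [max_len, head_dim//2]
        emb = torch.cat([freqs, freqs], dim=-1).unsqueeze(0)  # [1, max_len, head_dim]
        return {"cos": emb.cos(), "sin": emb.sin()}

    q = torch.randn(2, 12, 30, 64)   # [batch, heads, query_len, head_dim]
    k = torch.randn(2, 12, 50, 64)   # [batch, heads, key_len, head_dim]
    sinu = build_sinu(max_len=100, head_dim=64)
    # q_offset matters with a kv-cache, where the new query is a single step
    q_rot, k_rot = apply_rotary_pos_emb(q, k, sinu=sinu, q_offset=0)
    print(q_rot.shape, k_rot.shape)   # both keep their input shapes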
models/modules/embedding.py ADDED
@@ -0,0 +1,158 @@
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import logging
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ class TokenEmbedding(nn.Module):
23
+ def __init__(
24
+ self,
25
+ dim_model: int,
26
+ vocab_size: int,
27
+ dropout: float = 0.0,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.vocab_size = vocab_size
32
+ self.dim_model = dim_model
33
+
34
+ self.dropout = torch.nn.Dropout(p=dropout)
35
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
36
+
37
+ @property
38
+ def weight(self) -> torch.Tensor:
39
+ return self.word_embeddings.weight
40
+
41
+ def embedding(self, index: int) -> torch.Tensor:
42
+ return self.word_embeddings.weight[index : index + 1]
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ X = self.word_embeddings(x)
46
+ X = self.dropout(X)
47
+
48
+ return X
49
+
50
+
51
+ class SinePositionalEmbedding(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dim_model: int,
55
+ dropout: float = 0.0,
56
+ scale: bool = False,
57
+ alpha: bool = False,
58
+ ):
59
+ super().__init__()
60
+ self.dim_model = dim_model
61
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
62
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
63
+ self.dropout = torch.nn.Dropout(p=dropout)
64
+
65
+ self.reverse = False
66
+ self.pe = None
67
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
68
+
69
+ def extend_pe(self, x):
70
+ """Reset the positional encodings."""
71
+ if self.pe is not None:
72
+ if self.pe.size(1) >= x.size(1):
73
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
74
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
75
+ return
76
+ pe = torch.zeros(x.size(1), self.dim_model)
77
+ if self.reverse:
78
+ position = torch.arange(
79
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
80
+ ).unsqueeze(1)
81
+ else:
82
+ position = torch.arange(
83
+ 0, x.size(1), dtype=torch.float32
84
+ ).unsqueeze(1)
85
+ div_term = torch.exp(
86
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
87
+ * -(math.log(10000.0) / self.dim_model)
88
+ )
89
+ pe[:, 0::2] = torch.sin(position * div_term)
90
+ pe[:, 1::2] = torch.cos(position * div_term)
91
+ pe = pe.unsqueeze(0)
92
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
93
+
94
+ def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
95
+ self.extend_pe(x)
96
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
97
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
98
+ return self.dropout(output)
99
+
100
+
101
+ class SinePositionalEmbedding_progress(nn.Module):
102
+ def __init__(
103
+ self,
104
+ dim_model: int,
105
+ dropout: float = 0.0,
106
+ scale: bool = False,
107
+ alpha: bool = False,
108
+ args = None
109
+ ):
110
+ super().__init__()
111
+ self.args = args
112
+ self.dim_model = dim_model
113
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
114
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
115
+ self.dropout = torch.nn.Dropout(p=dropout)
116
+
117
+ self.reverse = False
118
+ self.div_term = torch.exp(
119
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
120
+ * -(math.log(args.sinusoidal_base) / self.dim_model)
121
+ ).unsqueeze(0).unsqueeze(0) # [1, 1, dim_model//2]
122
+ self.position = None
123
+ self.extend_position(torch.tensor(0.0).expand(1, 10000))
124
+ self.progress_scale = getattr(args, "progress_scale", 1.0)
125
+
126
+ def extend_position(self, x):
127
+ """Reset the positional encodings."""
128
+ if self.position is not None:
129
+ if self.div_term.dtype != x.dtype or self.div_term.device != x.device:
130
+ self.div_term = self.div_term.to(dtype=x.dtype, device=x.device)
131
+ if self.position.size(1) >= x.size(1):
132
+ if self.position.dtype != x.dtype or self.position.device != x.device:
133
+ self.position = self.position.to(dtype=x.dtype, device=x.device)
134
+ return
135
+ if self.reverse:
136
+ self.position = torch.arange(
137
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
138
+ ).unsqueeze(0).unsqueeze(2).to(x)
139
+ else:
140
+ self.position = torch.arange(
141
+ 0, x.size(1), dtype=torch.float32
142
+ ).unsqueeze(0).unsqueeze(2).to(x) # [1, seq_len, 1]
143
+
144
+ def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
145
+ assert x.ndim == 3, x.shape
146
+ self.extend_position(x)
147
+ x_lens = x_lens.unsqueeze(1).unsqueeze(2) # [B, 1, 1]
148
+ multiple = x_lens / (x_lens - 1)
149
+ progress = self.position[:, :x.shape[1]] * multiple / x_lens * self.progress_scale
150
+ # torch.set_printoptions(edgeitems=100)
151
+ # for i in range(x_lens.shape[0]):
152
+ # logging.info(f"{progress[i, :x_lens[i,0,0], 0]}")
153
+ invfreq = self.div_term * progress # might want to use a scale term here
154
+ pe = torch.zeros_like(x)
155
+ pe[..., 0::2] = torch.sin(invfreq)
156
+ pe[..., 1::2] = torch.cos(invfreq)
157
+ output = x * self.x_scale + self.alpha * pe
158
+ return self.dropout(output)
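The progress-based variant above rescales absolute frame indices into a 0-to-1 progress value per utterance, which is then multiplied by the inverse-frequency terms (div_term, optionally scaled by progress_scale) before taking sin/cos. A small standalone sketch of just that rescaling, with made-up lengths:

import torch

x_lens = torch.tensor([4.0, 8.0]).view(-1, 1, 1)  # [B, 1, 1], frame counts per utterance
position = torch.arange(8.0).view(1, -1, 1)       # [1, T, 1], absolute frame indices
multiple = x_lens / (x_lens - 1)                  # maps the last valid frame to exactly 1.0
progress = position * multiple / x_lens           # per-utterance progress in [0, 1]

print(progress[0, :4, 0])  # tensor([0.0000, 0.3333, 0.6667, 1.0000])
print(progress[1, :8, 0])  # tensor([0.0000, 0.1429, ..., 1.0000])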
models/modules/sampling.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def top_k_top_p_filtering(
5
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
6
+ ):
7
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
8
+ Args:
9
+ logits: logits distribution shape (batch size, vocabulary size)
10
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
11
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
12
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
13
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
14
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
15
+ """
16
+ if top_k > 0:
17
+ top_k = min(
18
+ max(top_k, min_tokens_to_keep), logits.size(-1)
19
+ ) # Safety check
20
+ # Remove all tokens with a probability less than the last token of the top-k
21
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
+ logits[indices_to_remove] = filter_value
23
+
24
+ if top_p < 1.0:
25
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
26
+ cumulative_probs = torch.cumsum(
27
+ F.softmax(sorted_logits, dim=-1), dim=-1
28
+ )
29
+
30
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
31
+ sorted_indices_to_remove = cumulative_probs > top_p
32
+ if min_tokens_to_keep > 1:
33
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
34
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
35
+ # Shift the indices to the right to keep also the first token above the threshold
36
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
37
+ ..., :-1
38
+ ].clone()
39
+ sorted_indices_to_remove[..., 0] = 0
40
+
41
+ # scatter sorted tensors to original indexing
42
+ indices_to_remove = sorted_indices_to_remove.scatter(
43
+ 1, sorted_indices, sorted_indices_to_remove
44
+ )
45
+ logits[indices_to_remove] = filter_value
46
+ return logits
47
+
48
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
49
+ # temperature: (`optional`) float
50
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
51
+ # top_k: (`optional`) int
52
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
53
+ # top_p: (`optional`) float
54
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
55
+
56
+ # Temperature (higher temperature => more likely to sample low probability tokens)
57
+ if temperature != 1.0:
58
+ logits = logits / temperature
59
+ # Top-p/top-k filtering
60
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
61
+ # Sample
62
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
63
+ return token
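A hedged usage sketch of the samplers above on dummy logits; the batch size, vocabulary size and the chosen top_k/top_p/temperature are illustrative only, and the import path is assumed from this repo's layout:

import torch
from models.modules.sampling import topk_sampling  # import path assumed from this repo's layout

torch.manual_seed(0)
logits = torch.randn(2, 1024)  # (batch, vocab); dummy values

# top_k_top_p_filtering edits logits in place, so pass a copy if the originals are still needed
token = topk_sampling(logits.clone(), top_k=20, top_p=0.9, temperature=0.8)
print(token.shape)  # torch.Size([2, 1])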
models/modules/scaling.py ADDED
@@ -0,0 +1,1406 @@
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
2
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import collections
20
+ import logging
21
+ import random
22
+ import math
23
+ from functools import reduce
24
+ from itertools import repeat
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch import Tensor
31
+ from torch.nn import Embedding as ScaledEmbedding
32
+
33
+ # from valle.utils import Transpose
34
+
35
+ class Transpose(nn.Identity):
36
+ """(N, T, D) -> (N, D, T)"""
37
+
38
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
39
+ return input.transpose(1, 2)
40
+
41
+ class ActivationBalancerFunction(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ x: Tensor,
46
+ scale_factor: Tensor,
47
+ sign_factor: Optional[Tensor],
48
+ channel_dim: int,
49
+ ) -> Tensor:
50
+ if channel_dim < 0:
51
+ channel_dim += x.ndim
52
+ ctx.channel_dim = channel_dim
53
+ xgt0 = x > 0
54
+ if sign_factor is None:
55
+ ctx.save_for_backward(xgt0, scale_factor)
56
+ else:
57
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
58
+ return x
59
+
60
+ @staticmethod
61
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
62
+ if len(ctx.saved_tensors) == 3:
63
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
64
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
+ scale_factor = scale_factor.unsqueeze(-1)
66
+ sign_factor = sign_factor.unsqueeze(-1)
67
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
68
+ else:
69
+ xgt0, scale_factor = ctx.saved_tensors
70
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
71
+ scale_factor = scale_factor.unsqueeze(-1)
72
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
73
+ neg_delta_grad = x_grad.abs() * factor
74
+ return (
75
+ x_grad - neg_delta_grad,
76
+ None,
77
+ None,
78
+ None,
79
+ )
80
+
81
+
82
+ def _compute_scale_factor(
83
+ x: Tensor,
84
+ channel_dim: int,
85
+ min_abs: float,
86
+ max_abs: float,
87
+ gain_factor: float,
88
+ max_factor: float,
89
+ ) -> Tensor:
90
+ if channel_dim < 0:
91
+ channel_dim += x.ndim
92
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
93
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
94
+
95
+ if min_abs == 0.0:
96
+ below_threshold = 0.0
97
+ else:
98
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
99
+ # x_abs_mean < min_abs.
100
+ below_threshold = (
101
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
102
+ ).clamp(min=0, max=max_factor)
103
+
104
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
105
+ min=0, max=max_factor
106
+ )
107
+
108
+ return below_threshold - above_threshold
109
+
110
+
111
+ def _compute_sign_factor(
112
+ x: Tensor,
113
+ channel_dim: int,
114
+ min_positive: float,
115
+ max_positive: float,
116
+ gain_factor: float,
117
+ max_factor: float,
118
+ ) -> Tensor:
119
+ if channel_dim < 0:
120
+ channel_dim += x.ndim
121
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
122
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
123
+ if min_positive == 0.0:
124
+ factor1 = 0.0
125
+ else:
126
+ # 0 if proportion_positive >= min_positive, else can be
127
+ # as large as max_factor.
128
+ factor1 = (
129
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
130
+ ).clamp_(min=0, max=max_factor)
131
+
132
+ if max_positive == 1.0:
133
+ factor2 = 0.0
134
+ else:
135
+ # 0 if self.proportion_positive <= max_positive, else can be
136
+ # as large as -max_factor.
137
+ factor2 = (
138
+ (proportion_positive - max_positive)
139
+ * (gain_factor / (1.0 - max_positive))
140
+ ).clamp_(min=0, max=max_factor)
141
+ sign_factor = factor1 - factor2
142
+ # require min_positive != 0 or max_positive != 1:
143
+ assert not isinstance(sign_factor, float)
144
+ return sign_factor
145
+
146
+
147
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
148
+ """
149
+ This object is used in class ActivationBalancer when the user specified
150
+ min_positive=0, max_positive=1, so there are no constraints on the signs
151
+ of the activations and only the absolute value has a constraint.
152
+ """
153
+
154
+ @staticmethod
155
+ def forward(
156
+ ctx,
157
+ x: Tensor,
158
+ sign_factor: Tensor,
159
+ scale_factor: Tensor,
160
+ channel_dim: int,
161
+ ) -> Tensor:
162
+ if channel_dim < 0:
163
+ channel_dim += x.ndim
164
+ ctx.channel_dim = channel_dim
165
+ xgt0 = x > 0
166
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
167
+ return x
168
+
169
+ @staticmethod
170
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
171
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
172
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
173
+ sign_factor = sign_factor.unsqueeze(-1)
174
+ scale_factor = scale_factor.unsqueeze(-1)
175
+
176
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
177
+ neg_delta_grad = x_grad.abs() * factor
178
+ return (
179
+ x_grad - neg_delta_grad,
180
+ None,
181
+ None,
182
+ None,
183
+ )
184
+
185
+
186
+ class RandomClampFunction(torch.autograd.Function):
187
+ @staticmethod
188
+ def forward(
189
+ ctx,
190
+ x: Tensor,
191
+ min: Optional[float],
192
+ max: Optional[float],
193
+ prob: float,
194
+ reflect: float,
195
+ ) -> Tensor:
196
+ x_clamped = torch.clamp(x, min=min, max=max)
197
+ mask = torch.rand_like(x) < prob
198
+ ans = torch.where(mask, x_clamped, x)
199
+ if x.requires_grad:
200
+ ctx.save_for_backward(ans == x)
201
+ ctx.reflect = reflect
202
+ if reflect != 0.0:
203
+ ans = ans * (1.0 + reflect) - (x * reflect)
204
+ return ans
205
+
206
+ @staticmethod
207
+ def backward(
208
+ ctx, ans_grad: Tensor
209
+ ) -> Tuple[Tensor, None, None, None, None]:
210
+ (is_same,) = ctx.saved_tensors
211
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
212
+ reflect = ctx.reflect
213
+ if reflect != 0.0:
214
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
215
+ return x_grad, None, None, None, None
216
+
217
+
218
+ def random_clamp(
219
+ x: Tensor,
220
+ min: Optional[float] = None,
221
+ max: Optional[float] = None,
222
+ prob: float = 0.5,
223
+ reflect: float = 0.0,
224
+ ):
225
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
226
+
227
+
228
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
229
+ """
230
+ A randomized way of casting a floating point value to half precision.
231
+ """
232
+ if x.dtype == torch.float16:
233
+ return x
234
+ x_abs = x.abs()
235
+ is_too_small = x_abs < min_abs
236
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
237
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
238
+ # for those elements].
239
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
240
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
241
+
242
+
243
+ class RandomGradFunction(torch.autograd.Function):
244
+ """
245
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
246
+ randomized approach that preserves expectations (intended to reduce roundoff).
247
+ """
248
+
249
+ @staticmethod
250
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
251
+ ctx.min_abs = min_abs
252
+ return x
253
+
254
+ @staticmethod
255
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
256
+ if ans_grad.dtype == torch.float16:
257
+ return (
258
+ random_cast_to_half(
259
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
260
+ ),
261
+ None,
262
+ )
263
+ else:
264
+ return ans_grad, None
265
+
266
+
267
+ class RandomGrad(torch.nn.Module):
268
+ """
269
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
270
+ accuracy of training when using amp (automatic mixed precision)
271
+ """
272
+
273
+ def __init__(self, min_abs: float = 5.0e-06):
274
+ super(RandomGrad, self).__init__()
275
+ self.min_abs = min_abs
276
+
277
+ def forward(self, x: Tensor):
278
+ if (
279
+ torch.jit.is_scripting()
280
+ or not self.training
281
+ or torch.jit.is_tracing()
282
+ ):
283
+ return x
284
+ else:
285
+ return RandomGradFunction.apply(x, self.min_abs)
286
+
287
+
288
+ class SoftmaxFunction(torch.autograd.Function):
289
+ """
290
+ Tries to handle half-precision derivatives in a randomized way that should
291
+ be more accurate for training than the default behavior.
292
+ """
293
+
294
+ @staticmethod
295
+ def forward(ctx, x: Tensor, dim: int):
296
+ ans = x.softmax(dim=dim)
297
+ # if x dtype is float16, x.softmax() returns a float32 because
298
+ # (presumably) that op does not support float16, and autocast
299
+ # is enabled.
300
+ if torch.is_autocast_enabled():
301
+ ans = ans.to(torch.float16)
302
+ ctx.save_for_backward(ans)
303
+ ctx.x_dtype = x.dtype
304
+ ctx.dim = dim
305
+ return ans
306
+
307
+ @staticmethod
308
+ def backward(ctx, ans_grad: Tensor):
309
+ (ans,) = ctx.saved_tensors
310
+ with torch.cuda.amp.autocast(enabled=False):
311
+ ans_grad = ans_grad.to(torch.float32)
312
+ ans = ans.to(torch.float32)
313
+ x_grad = ans_grad * ans
314
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
315
+ return x_grad, None
316
+
317
+
318
+ def softmax(x: Tensor, dim: int):
319
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
320
+ return x.softmax(dim)
321
+
322
+ return SoftmaxFunction.apply(x, dim)
323
+
324
+
325
+ class MaxEigLimiterFunction(torch.autograd.Function):
326
+ @staticmethod
327
+ def forward(
328
+ ctx,
329
+ x: Tensor,
330
+ coeffs: Tensor,
331
+ direction: Tensor,
332
+ channel_dim: int,
333
+ grad_scale: float,
334
+ ) -> Tensor:
335
+ ctx.channel_dim = channel_dim
336
+ ctx.grad_scale = grad_scale
337
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
338
+ return x
339
+
340
+ @staticmethod
341
+ def backward(ctx, x_grad, *args):
342
+ with torch.enable_grad():
343
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
344
+ x_orig.requires_grad = True
345
+ num_channels = x_orig.shape[ctx.channel_dim]
346
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
347
+ new_direction.requires_grad = False
348
+ x = x - x.mean(dim=0)
349
+ x_var = (x ** 2).mean()
350
+ x_residual = x - coeffs * new_direction
351
+ x_residual_var = (x_residual ** 2).mean()
352
+ # `variance_proportion` is the proportion of the variance accounted for
353
+ # by the top eigen-direction. This is to be minimized.
354
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
355
+ variance_proportion.backward()
356
+ x_orig_grad = x_orig.grad
357
+ x_extra_grad = (
358
+ x_orig.grad
359
+ * ctx.grad_scale
360
+ * x_grad.norm()
361
+ / (x_orig_grad.norm() + 1.0e-20)
362
+ )
363
+ return x_grad + x_extra_grad.detach(), None, None, None, None
364
+
365
+
366
+ class BasicNorm(torch.nn.Module):
367
+ """
368
+ This is intended to be a simpler, and hopefully cheaper, replacement for
369
+ LayerNorm. The observation this is based on, is that Transformer-type
370
+ networks, especially with pre-norm, sometimes seem to set one of the
371
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
372
+ the LayerNorm because the output magnitude is then not strongly dependent
373
+ on the other (useful) features. Presumably the weight and bias of the
374
+ LayerNorm are required to allow it to do this.
375
+
376
+ So the idea is to introduce this large constant value as an explicit
377
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
378
+ doesn't have to do this trick. We make the "eps" learnable.
379
+
380
+ Args:
381
+ num_channels: the number of channels, e.g. 512.
382
+ channel_dim: the axis/dimension corresponding to the channel,
383
+ interpreted as an offset from the input's ndim if negative.
384
+ this is NOT the num_channels; it should typically be one of
385
+ {-2, -1, 0, 1, 2, 3}.
386
+ eps: the initial "epsilon" that we add as ballast in:
387
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
388
+ Note: our epsilon is actually large, but we keep the name
389
+ to indicate the connection with conventional LayerNorm.
390
+ learn_eps: if true, we learn epsilon; if false, we keep it
391
+ at the initial value.
392
+ eps_min: float
393
+ eps_max: float
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ num_channels: int,
399
+ channel_dim: int = -1, # CAUTION: see documentation.
400
+ eps: float = 0.25,
401
+ learn_eps: bool = True,
402
+ eps_min: float = -3.0,
403
+ eps_max: float = 3.0,
404
+ ) -> None:
405
+ super(BasicNorm, self).__init__()
406
+ self.num_channels = num_channels
407
+ self.channel_dim = channel_dim
408
+ if learn_eps:
409
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
410
+ else:
411
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
412
+ self.eps_min = eps_min
413
+ self.eps_max = eps_max
414
+
415
+ def forward(self, x: Tensor) -> Tensor:
416
+ assert x.shape[self.channel_dim] == self.num_channels
417
+ eps = self.eps
418
+ if self.training and random.random() < 0.25:
419
+ # with probability 0.25, in training mode, clamp eps between the min
420
+ # and max; this will encourage it to learn parameters within the
421
+ # allowed range by making parameters that are outside the allowed
422
+ # range noisy.
423
+
424
+ # gradients to allow the parameter to get back into the allowed region if it happens to exit it.
425
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
426
+ scales = (
427
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
428
+ ) ** -0.5
429
+ return x * scales
430
+
431
+
432
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
433
+ """
434
+ Behaves like a constructor of a modified version of nn.Linear
435
+ that gives an easy way to set the default initial parameter scale.
436
+
437
+ Args:
438
+ Accepts the standard args and kwargs that nn.Linear accepts
439
+ e.g. in_features, out_features, bias=False.
440
+
441
+ initial_scale: you can override this if you want to increase
442
+ or decrease the initial magnitude of the module's output
443
+ (affects the initialization of weight_scale and bias_scale).
444
+ Another option, if you want to do something like this, is
445
+ to re-initialize the parameters.
446
+ """
447
+ ans = nn.Linear(*args, **kwargs)
448
+ with torch.no_grad():
449
+ ans.weight[:] *= initial_scale
450
+ if ans.bias is not None:
451
+ torch.nn.init.uniform_(
452
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
453
+ )
454
+ return ans
455
+
456
+
457
+ def ScaledConv1d(
458
+ *args,
459
+ initial_scale: float = 1.0,
460
+ kernel_size: int = 3,
461
+ padding: str = "same",
462
+ **kwargs,
463
+ ) -> nn.Conv1d:
464
+ """
465
+ Behaves like a constructor of a modified version of nn.Conv1d
466
+ that gives an easy way to set the default initial parameter scale.
467
+
468
+ Args:
469
+ Accepts the standard args and kwargs that nn.Linear accepts
470
+ e.g. in_features, out_features, bias=False.
471
+
472
+ initial_scale: you can override this if you want to increase
473
+ or decrease the initial magnitude of the module's output
474
+ (affects the initialization of weight_scale and bias_scale).
475
+ Another option, if you want to do something like this, is
476
+ to re-initialize the parameters.
477
+ """
478
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
479
+ with torch.no_grad():
480
+ ans.weight[:] *= initial_scale
481
+ if ans.bias is not None:
482
+ torch.nn.init.uniform_(
483
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
484
+ )
485
+ return ans
486
+
487
+
488
+ def TransposeScaledConv1d(
489
+ *args,
490
+ initial_scale: float = 1.0,
491
+ kernel_size: int = 3,
492
+ padding: str = "same",
493
+ **kwargs,
494
+ ) -> nn.Sequential:
495
+ """
496
+ Transpose -> ScaledConv1d
497
+ """
498
+ return nn.Sequential(
499
+ Transpose(),
500
+ ScaledConv1d(
501
+ *args,
502
+ initial_scale=initial_scale,
503
+ kernel_size=kernel_size,
504
+ padding=padding,
505
+ **kwargs,
506
+ ),
507
+ )
508
+
509
+
510
+ def ScaledConv1dTranspose(
511
+ *args,
512
+ initial_scale: float = 1.0,
513
+ kernel_size: int = 3,
514
+ padding: str = "same",
515
+ **kwargs,
516
+ ) -> nn.Sequential:
517
+ """
518
+ ScaledConv1d -> Transpose
519
+ """
520
+ return nn.Sequential(
521
+ ScaledConv1d(
522
+ *args,
523
+ initial_scale=initial_scale,
524
+ kernel_size=kernel_size,
525
+ padding=padding,
526
+ **kwargs,
527
+ ),
528
+ Transpose(),
529
+ )
530
+
531
+
532
+ def TransposeConv1d(
533
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
534
+ ) -> nn.Sequential:
535
+ """
536
+ Transpose -> Conv1d
537
+ """
538
+ return nn.Sequential(
539
+ Transpose(),
540
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
541
+ )
542
+
543
+
544
+ def Conv1dTranspose(
545
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
546
+ ) -> nn.Sequential:
547
+ """
548
+ ScaledConv1d -> Transpose
549
+ """
550
+ return nn.Sequential(
551
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
552
+ Transpose(),
553
+ )
554
+
555
+
556
+ class SRLinear(nn.Linear):
557
+ """https://arxiv.org/abs/2303.06296
558
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
559
+ """
560
+
561
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
562
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
563
+ self.register_buffer(
564
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
565
+ )
566
+ with torch.no_grad():
567
+ sigma = self.get_sigma()
568
+ self.register_buffer("spectral_norm", sigma)
569
+ self.sigma = nn.Parameter(torch.ones(1))
570
+
571
+ def get_sigma(self):
572
+ with torch.no_grad():
573
+ u = self.u
574
+ v = self.weight.mv(u)
575
+ v = nn.functional.normalize(v, dim=0)
576
+ u = self.weight.T.mv(v)
577
+ u = nn.functional.normalize(u, dim=0)
578
+ self.u.data.copy_(u)
579
+ return torch.einsum("c,cd,d->", v, self.weight, u)
580
+
581
+ def get_weight(self):
582
+ sigma = self.get_sigma()
583
+ if self.training:
584
+ self.spectral_norm.data.copy_(sigma)
585
+ weight = (self.sigma / sigma) * self.weight
586
+ return weight
587
+
588
+ def forward(self, x):
589
+ return nn.functional.linear(x, self.get_weight(), self.bias)
590
+
591
+
592
+ class SRConv1d(SRLinear):
593
+ def __init__(
594
+ self,
595
+ in_features,
596
+ out_features,
597
+ kernel_size,
598
+ stride: int = 1,
599
+ padding: str = "same",
600
+ bias: bool = True,
601
+ **kwargs,
602
+ ):
603
+ in_features = in_features * kernel_size
604
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
605
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
606
+ self.kernel_size = kernel_size
607
+ self.stride = stride
608
+ self.padding = padding
609
+
610
+ def forward(self, x):
611
+ in_features = self.in_features // self.kernel_size
612
+ weight = self.get_weight().view(
613
+ self.out_features, in_features, self.kernel_size
614
+ )
615
+ return nn.functional.conv1d(
616
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
617
+ )
618
+
619
+
620
+ def TransposeSRConv1d(
621
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
622
+ ) -> nn.Sequential:
623
+ """
624
+ Transpose -> SRConv1d
625
+ """
626
+ return nn.Sequential(
627
+ Transpose(),
628
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
629
+ )
630
+
631
+
632
+ def SRConv1dTranspose(
633
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
634
+ ) -> nn.Sequential:
635
+ """
636
+ SRConv1d -> Transpose
637
+ """
638
+ return nn.Sequential(
639
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
640
+ Transpose(),
641
+ )
642
+
643
+
644
+ class ActivationBalancer(torch.nn.Module):
645
+ """
646
+ Modifies the backpropped derivatives of a function to try to encourage, for
647
+ each channel, that it is positive at least a proportion `threshold` of the
648
+ time. It does this by multiplying negative derivative values by up to
649
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
650
+ interpolated from 1 at the threshold to those extremal values when none
651
+ of the inputs are positive.
652
+
653
+ Args:
654
+ num_channels: the number of channels
655
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
656
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
657
+ min_positive: the minimum, per channel, of the proportion of the time
658
+ that (x > 0), below which we start to modify the derivatives.
659
+ max_positive: the maximum, per channel, of the proportion of the time
660
+ that (x > 0), above which we start to modify the derivatives.
661
+ max_factor: the maximum factor by which we modify the derivatives for
662
+ either the sign constraint or the magnitude constraint;
663
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
664
+ values in the range [0.98..1.02].
665
+ sign_gain_factor: determines the 'gain' with which we increase the
666
+ change in gradient once the constraints on min_positive and max_positive
667
+ are violated.
668
+ scale_gain_factor: determines the 'gain' with which we increase the
669
+ change in gradient once the constraints on min_abs and max_abs
670
+ are violated.
671
+ min_abs: the minimum average-absolute-value difference from the mean
672
+ value per channel, which we allow, before we start to modify
673
+ the derivatives to prevent this.
674
+ max_abs: the maximum average-absolute-value difference from the mean
675
+ value per channel, which we allow, before we start to modify
676
+ the derivatives to prevent this.
677
+ min_prob: determines the minimum probability with which we modify the
678
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
679
+ on each forward(). This is done randomly to prevent all layers
680
+ from doing it at the same time. Early in training we may use
681
+ higher probabilities than this; it will decay to this value.
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ num_channels: int,
687
+ channel_dim: int,
688
+ min_positive: float = 0.05,
689
+ max_positive: float = 0.95,
690
+ max_factor: float = 0.04,
691
+ sign_gain_factor: float = 0.01,
692
+ scale_gain_factor: float = 0.02,
693
+ min_abs: float = 0.2,
694
+ max_abs: float = 100.0,
695
+ min_prob: float = 0.1,
696
+ ):
697
+ super(ActivationBalancer, self).__init__()
698
+ self.num_channels = num_channels
699
+ self.channel_dim = channel_dim
700
+ self.min_positive = min_positive
701
+ self.max_positive = max_positive
702
+ self.max_factor = max_factor
703
+ self.min_abs = min_abs
704
+ self.max_abs = max_abs
705
+ self.min_prob = min_prob
706
+ self.sign_gain_factor = sign_gain_factor
707
+ self.scale_gain_factor = scale_gain_factor
708
+
709
+ # count measures how many times the forward() function has been called.
710
+ # We occasionally sync this to a tensor called `count`, that exists to
711
+ # make sure it is synced to disk when we load and save the model.
712
+ self.cpu_count = 0
713
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
714
+
715
+ def forward(self, x: Tensor) -> Tensor:
716
+ if (
717
+ torch.jit.is_scripting()
718
+ or not x.requires_grad
719
+ or torch.jit.is_tracing()
720
+ ):
721
+ return _no_op(x)
722
+
723
+ count = self.cpu_count
724
+ self.cpu_count += 1
725
+
726
+ if random.random() < 0.01:
727
+ # Occasionally sync self.cpu_count with self.count.
728
+ # count affects the decay of 'prob'. don't do this on every iter,
729
+ # because syncing with the GPU is slow.
730
+ self.cpu_count = max(self.cpu_count, self.count.item())
731
+ self.count.fill_(self.cpu_count)
732
+
733
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
734
+ # a floor at min_prob (==0.1, by default)
735
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
736
+
737
+ if random.random() < prob:
738
+ sign_gain_factor = 0.5
739
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
740
+ sign_factor = _compute_sign_factor(
741
+ x,
742
+ self.channel_dim,
743
+ self.min_positive,
744
+ self.max_positive,
745
+ gain_factor=self.sign_gain_factor / prob,
746
+ max_factor=self.max_factor,
747
+ )
748
+ else:
749
+ sign_factor = None
750
+
751
+ scale_factor = _compute_scale_factor(
752
+ x.detach(),
753
+ self.channel_dim,
754
+ min_abs=self.min_abs,
755
+ max_abs=self.max_abs,
756
+ gain_factor=self.scale_gain_factor / prob,
757
+ max_factor=self.max_factor,
758
+ )
759
+ return ActivationBalancerFunction.apply(
760
+ x,
761
+ scale_factor,
762
+ sign_factor,
763
+ self.channel_dim,
764
+ )
765
+ else:
766
+ return _no_op(x)
767
+
768
+
769
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
770
+ """
771
+ Returns x unmodified, but in backprop will put a penalty for the excess of
772
+ the absolute values of elements of x over the limit "limit". E.g. if
773
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
774
+
775
+ Caution: the value of this penalty will be affected by grad scaling used
776
+ in automatic mixed precision training. For the purposes we use it for here,
777
+ it shouldn't really matter, or may even be helpful; we just use this
778
+ to disallow really implausible values of scores to be given to softmax.
779
+ """
780
+ x_sign = x.sign()
781
+ over_limit = (x.abs() - limit) > 0
782
+ # The following is a memory efficient way to penalize the absolute values of
783
+ # x that's over the limit. (The memory efficiency comes when you think
784
+ # about which items torch needs to cache for the autograd, and which ones it
785
+ # can throw away). The numerical value of aux_loss as computed here will
786
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
787
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
788
+ # limit).relu().
789
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
790
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
791
+ # sum() due to how with_loss() works.
792
+ x = with_loss(x, aux_loss)
793
+ # you must use x for something, or this will be ineffective.
794
+ return x
795
+
796
+
797
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
798
+ if x.ndim == 2:
799
+ return x.diag()
800
+ else:
801
+ (batch, dim, dim) = x.shape
802
+ x = x.reshape(batch, dim * dim)
803
+ x = x[:, :: dim + 1]
804
+ assert x.shape == (batch, dim)
805
+ return x
806
+
807
+
808
+ def _whitening_metric(x: Tensor, num_groups: int):
809
+ """
810
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
811
+ the centered feature covariance are the same within each group's covariance matrix
812
+ and also between groups.
813
+ Args:
814
+ x: a Tensor of shape (*, num_channels)
815
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
816
+ Returns:
817
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
818
+ greater than 1.0 otherwise.
819
+ """
820
+ assert x.dtype != torch.float16
821
+ x = x.reshape(-1, x.shape[-1])
822
+ (num_frames, num_channels) = x.shape
823
+ assert num_channels % num_groups == 0
824
+ channels_per_group = num_channels // num_groups
825
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
826
+ # x now has shape (num_groups, num_frames, channels_per_group)
827
+ # subtract the mean so we use the centered, not uncentered, covariance.
828
+ # My experience has been that when we "mess with the gradients" like this,
829
+ # it's better not do anything that tries to move the mean around, because
830
+ # that can easily cause instability.
831
+ x = x - x.mean(dim=1, keepdim=True)
832
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
833
+ x_covar = torch.matmul(x.transpose(1, 2), x)
834
+ x_covar_mean_diag = _diag(x_covar).mean()
835
+ # the following expression is what we'd get if we took the matrix product
836
+ # of each covariance and measured the mean of its trace, i.e.
837
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
838
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
839
+ num_groups * channels_per_group
840
+ )
841
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
842
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
843
+ return metric
844
+
845
+
846
+ class WhiteningPenaltyFunction(torch.autograd.Function):
847
+ @staticmethod
848
+ def forward(
849
+ ctx,
850
+ x: Tensor,
851
+ num_groups: int,
852
+ whitening_limit: float,
853
+ grad_scale: float,
854
+ ) -> Tensor:
855
+ ctx.save_for_backward(x)
856
+ ctx.num_groups = num_groups
857
+ ctx.whitening_limit = whitening_limit
858
+ ctx.grad_scale = grad_scale
859
+ return x
860
+
861
+ @staticmethod
862
+ def backward(ctx, x_grad: Tensor):
863
+ (x_orig,) = ctx.saved_tensors
864
+ with torch.enable_grad():
865
+ with torch.cuda.amp.autocast(enabled=False):
866
+ x_detached = x_orig.to(torch.float32).detach()
867
+ x_detached.requires_grad = True
868
+
869
+ metric = _whitening_metric(x_detached, ctx.num_groups)
870
+
871
+ if random.random() < 0.005 or __name__ == "__main__":
872
+ logging.info(
873
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
874
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
875
+ )
876
+
877
+ (metric - ctx.whitening_limit).relu().backward()
878
+ penalty_grad = x_detached.grad
879
+ scale = ctx.grad_scale * (
880
+ x_grad.to(torch.float32).norm()
881
+ / (penalty_grad.norm() + 1.0e-20)
882
+ )
883
+ penalty_grad = penalty_grad * scale
884
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
885
+
886
+
887
+ class Whiten(nn.Module):
888
+ def __init__(
889
+ self,
890
+ num_groups: int,
891
+ whitening_limit: float,
892
+ prob: Union[float, Tuple[float, float]],
893
+ grad_scale: float,
894
+ ):
895
+ """
896
+ Args:
897
+ num_groups: the number of groups to divide the channel dim into before
898
+ whitening. We will attempt to make the feature covariance
899
+ within each group, after mean subtraction, as "white" as possible,
900
+ while having the same trace across all groups.
901
+ whitening_limit: a value greater than 1.0, that dictates how much
902
+ freedom we have to violate the constraints. 1.0 would mean perfectly
903
+ white, with exactly the same trace across groups; larger values
904
+ give more freedom. E.g. 2.0.
905
+ prob: the probability with which we apply the gradient modification
906
+ (also affects the grad scale). May be supplied as a float,
907
+ or as a pair (min_prob, max_prob)
908
+
909
+ grad_scale: determines the scale on the gradient term from this object,
910
+ relative to the rest of the gradient on the attention weights.
911
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
912
+ """
913
+ super(Whiten, self).__init__()
914
+ assert num_groups >= 1
915
+ assert whitening_limit >= 1
916
+ assert grad_scale >= 0
917
+ self.num_groups = num_groups
918
+ self.whitening_limit = whitening_limit
919
+ if isinstance(prob, float):
920
+ assert 0 < prob <= 1
921
+ self.prob = prob
922
+ else:
923
+ (self.min_prob, self.max_prob) = prob
924
+ assert 0 < self.min_prob < self.max_prob <= 1
925
+ self.prob = self.max_prob
926
+
927
+ self.grad_scale = grad_scale
928
+
929
+ def forward(self, x: Tensor) -> Tensor:
930
+ """
931
+ In the forward pass, this function just returns the input unmodified.
932
+ In the backward pass, it will modify the gradients to ensure that the
933
+ distribution in each group has close to (lambda times I) as the covariance
934
+ after mean subtraction, with the same lambda across groups.
935
+ For whitening_limit > 1, there will be more freedom to violate this
936
+ constraint.
937
+
938
+ Args:
939
+ x: the input of shape (*, num_channels)
940
+
941
+ Returns:
942
+ x, unmodified. You should make sure
943
+ you use the returned value, or the graph will be freed
944
+ and nothing will happen in backprop.
945
+ """
946
+ if (
947
+ not x.requires_grad
948
+ or random.random() > self.prob
949
+ or self.grad_scale == 0
950
+ ):
951
+ return _no_op(x)
952
+ else:
953
+ if hasattr(self, "min_prob") and random.random() < 0.25:
954
+ # occasionally switch between min_prob and max_prob, based on whether
955
+ # we are above or below the threshold.
956
+ if (
957
+ _whitening_metric(x.to(torch.float32), self.num_groups)
958
+ > self.whitening_limit
959
+ ):
960
+ # there would be a change to the grad.
961
+ self.prob = self.max_prob
962
+ else:
963
+ self.prob = self.min_prob
964
+
965
+ return WhiteningPenaltyFunction.apply(
966
+ x, self.num_groups, self.whitening_limit, self.grad_scale
967
+ )
968
+
969
+
970
+ class WithLoss(torch.autograd.Function):
971
+ @staticmethod
972
+ def forward(ctx, x: Tensor, y: Tensor):
973
+ ctx.y_shape = y.shape
974
+ return x
975
+
976
+ @staticmethod
977
+ def backward(ctx, ans_grad: Tensor):
978
+ return ans_grad, torch.ones(
979
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
980
+ )
981
+
982
+
983
+ def with_loss(x, y):
984
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
985
+ return x
986
+ # returns x but adds y.sum() to the loss function.
987
+ return WithLoss.apply(x, y)
988
+
989
+
990
+ def _no_op(x: Tensor) -> Tensor:
991
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
992
+ return x
993
+ else:
994
+ # a no-op function that will have a node in the autograd graph,
995
+ # to avoid certain bugs relating to backward hooks
996
+ return x.chunk(1, dim=-1)[0]
997
+
998
+
999
+ class Identity(torch.nn.Module):
1000
+ def __init__(self):
1001
+ super(Identity, self).__init__()
1002
+
1003
+ def forward(self, x):
1004
+ return _no_op(x)
1005
+
1006
+
1007
+ class MaxEig(torch.nn.Module):
1008
+ """
1009
+ Modifies the backpropped derivatives of a function to try to discourage
1010
+ that any given direction in activation space accounts for more than
1011
+ a specified proportion of the covariance (e.g. 0.2).
1012
+
1013
+
1014
+ Args:
1015
+ num_channels: the number of channels
1016
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1017
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1018
+ max_var_per_eig: the maximum proportion of the variance of the
1019
+ features/channels, after mean subtraction, that can come from
1020
+ any given eigenvalue.
1021
+ min_prob: the minimum probability with which we apply this during any invocation
1022
+ of forward(), assuming last time we applied the constraint it was
1023
+ not active; supplied for speed.
1024
+ scale: determines the scale with which we modify the gradients, relative
1025
+ to the existing / unmodified gradients
1026
+ """
1027
+
1028
+ def __init__(
1029
+ self,
1030
+ num_channels: int,
1031
+ channel_dim: int,
1032
+ max_var_per_eig: float = 0.2,
1033
+ min_prob: float = 0.01,
1034
+ scale: float = 0.01,
1035
+ ):
1036
+ super(MaxEig, self).__init__()
1037
+ self.num_channels = num_channels
1038
+ self.channel_dim = channel_dim
1039
+ self.scale = scale
1040
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1041
+ self.max_var_per_eig = max_var_per_eig
1042
+
1043
+ # we figure out the dominant direction using the power method: starting with
1044
+ # a random vector, keep multiplying by the covariance and renormalizing.
1045
+ with torch.no_grad():
1046
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1047
+ # random parameters unchanged for comparison
1048
+ direction = torch.arange(num_channels).to(torch.float)
1049
+ direction = direction / direction.norm()
1050
+ self.register_buffer("max_eig_direction", direction)
1051
+
1052
+ self.min_prob = min_prob
1053
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1054
+ # We'll regress this towards prob, each time we try to apply it and it is not
1055
+ # active.
1056
+ self.cur_prob = 1.0
1057
+
1058
+ def forward(self, x: Tensor) -> Tensor:
1059
+ if (
1060
+ torch.jit.is_scripting()
1061
+ or self.max_var_per_eig <= 0
1062
+ or random.random() > self.cur_prob
1063
+ or torch.jit.is_tracing()
1064
+ ):
1065
+ return _no_op(x)
1066
+
1067
+ with torch.cuda.amp.autocast(enabled=False):
1068
+ eps = 1.0e-20
1069
+ orig_x = x
1070
+ x = x.to(torch.float32)
1071
+ with torch.no_grad():
1072
+ x = x.transpose(self.channel_dim, -1).reshape(
1073
+ -1, self.num_channels
1074
+ )
1075
+ x = x - x.mean(dim=0)
1076
+ new_direction, coeffs = self._find_direction_coeffs(
1077
+ x, self.max_eig_direction
1078
+ )
1079
+ x_var = (x ** 2).mean()
1080
+ x_residual = x - coeffs * new_direction
1081
+ x_residual_var = (x_residual ** 2).mean()
1082
+
1083
+ # `variance_proportion` is the proportion of the variance accounted for
1084
+ # by the top eigen-direction.
1085
+ variance_proportion = (x_var - x_residual_var) / (
1086
+ x_var + 1.0e-20
1087
+ )
1088
+
1089
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1090
+ self._set_direction(
1091
+ 0.1 * self.max_eig_direction + new_direction
1092
+ )
1093
+
1094
+ if random.random() < 0.01 or __name__ == "__main__":
1095
+ logging.info(
1096
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1097
+ )
1098
+
1099
+ if variance_proportion >= self.max_var_per_eig:
1100
+ # The constraint is active. Note, we should quite rarely
1101
+ # reach here, only near the beginning of training if we are
1102
+ # starting to diverge, should this constraint be active.
1103
+ cur_prob = self.cur_prob
1104
+ self.cur_prob = (
1105
+ 1.0 # next time, do the update with probability 1.0.
1106
+ )
1107
+ return MaxEigLimiterFunction.apply(
1108
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1109
+ )
1110
+ else:
1111
+ # let self.cur_prob exponentially approach self.min_prob, as
1112
+ # long as the constraint is inactive.
1113
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1114
+ return orig_x
1115
+
1116
+ def _set_direction(self, direction: Tensor):
1117
+ """
1118
+ Sets self.max_eig_direction to a normalized version of `direction`
1119
+ """
1120
+ direction = direction.detach()
1121
+ direction = direction / direction.norm()
1122
+ direction_sum = direction.sum().item()
1123
+ if direction_sum - direction_sum == 0: # no inf/nan
1124
+ self.max_eig_direction[:] = direction
1125
+ else:
1126
+ logging.info(
1127
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1128
+ f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1129
+ )
1130
+
1131
+ def _find_direction_coeffs(
1132
+ self, x: Tensor, prev_direction: Tensor
1133
+ ) -> Tuple[Tensor, Tensor]:
1134
+ """
1135
+ Figure out (an approximation to) the proportion of the variance of a set of
1136
+ feature vectors that can be attributed to the top eigen-direction.
1137
+ Args:
1138
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1139
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1140
+ of the top eigen-direction, or a random direction if this is the first
1141
+ iteration. Does not have to be normalized, but should be nonzero.
1142
+
1143
+ Returns: (cur_direction, coeffs), where:
1144
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1145
+ estimate of the top eigen-direction.
1146
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1147
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1148
+ """
1149
+ (num_frames, num_channels) = x.shape
1150
+ assert num_channels > 1 and num_frames > 1
1151
+ assert prev_direction.shape == (num_channels,)
1152
+ # `coeffs` are the coefficients of `prev_direction` in x.
1153
+ # actually represent the coeffs up to a constant positive factor.
1154
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1155
+ cur_direction = (x * coeffs).sum(dim=0) / (
1156
+ (coeffs ** 2).sum() + 1.0e-20
1157
+ )
1158
+ return cur_direction, coeffs
1159
+
1160
+
1161
+ class DoubleSwishFunction(torch.autograd.Function):
1162
+ """
1163
+ double_swish(x) = x * torch.sigmoid(x-1)
1164
+ This is a definition, originally motivated by its close numerical
1165
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1166
+
1167
+ Memory-efficient derivative computation:
1168
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1169
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1170
+ Now, s'(x) = s(x) * (1-s(x)).
1171
+ double_swish'(x) = x * s'(x) + s(x).
1172
+ = x * s(x) * (1-s(x)) + s(x).
1173
+ = double_swish(x) * (1-s(x)) + s(x)
1174
+ ... so we just need to remember s(x) but not x itself.
1175
+ """
1176
+
1177
+ @staticmethod
1178
+ def forward(ctx, x: Tensor) -> Tensor:
1179
+ requires_grad = x.requires_grad
1180
+ x_dtype = x.dtype
1181
+ if x.dtype == torch.float16:
1182
+ x = x.to(torch.float32)
1183
+
1184
+ s = torch.sigmoid(x - 1.0)
1185
+ y = x * s
1186
+
1187
+ if requires_grad:
1188
+ deriv = y * (1 - s) + s
1189
+ # notes on derivative of x * sigmoid(x - 1):
1190
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1191
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
1192
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1193
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1194
+ # floors), should be expectation-preserving.
1195
+ floor = -0.043637
1196
+ ceil = 1.2
1197
+ d_scaled = (deriv - floor) * (
1198
+ 255.0 / (ceil - floor)
1199
+ ) + torch.rand_like(deriv)
1200
+ if __name__ == "__main__":
1201
+ # for self-testing only.
1202
+ assert d_scaled.min() >= 0.0
1203
+ assert d_scaled.max() < 256.0
1204
+ d_int = d_scaled.to(torch.uint8)
1205
+ ctx.save_for_backward(d_int)
1206
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1207
+ y = y.to(torch.float16)
1208
+ return y
1209
+
1210
+ @staticmethod
1211
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1212
+ (d,) = ctx.saved_tensors
1213
+ # the same constants as used in forward pass.
1214
+ floor = -0.043637
1215
+ ceil = 1.2
1216
+ d = d * ((ceil - floor) / 255.0) + floor
1217
+ return y_grad * d
1218
+
1219
+
1220
+ class DoubleSwish(torch.nn.Module):
1221
+ def forward(self, x: Tensor) -> Tensor:
1222
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1223
+ that we approximate closely with x * sigmoid(x-1).
1224
+ """
1225
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1226
+ return x * torch.sigmoid(x - 1.0)
1227
+ return DoubleSwishFunction.apply(x)
1228
+
1229
+
1230
+ def BalancedDoubleSwish(
1231
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1232
+ ) -> nn.Sequential:
1233
+ """
1234
+ ActivationBalancer -> DoubleSwish
1235
+ """
1236
+ balancer = ActivationBalancer(
1237
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1238
+ )
1239
+ return nn.Sequential(
1240
+ balancer,
1241
+ DoubleSwish(),
1242
+ )
1243
+
1244
+
1245
+ def _test_max_eig():
1246
+ for proportion in [0.1, 0.5, 10.0]:
1247
+ logging.info(f"proportion = {proportion}")
1248
+ x = torch.randn(100, 128)
1249
+ direction = torch.randn(128)
1250
+ coeffs = torch.randn(100, 1)
1251
+ x += proportion * direction * coeffs
1252
+
1253
+ x.requires_grad = True
1254
+
1255
+ num_channels = 128
1256
+ m = MaxEig(
1257
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1258
+ ) # grad_scale
1259
+
1260
+ for _ in range(4):
1261
+ y = m(x)
1262
+
1263
+ y_grad = torch.randn_like(x)
1264
+ y.backward(gradient=y_grad)
1265
+
1266
+ if proportion < 0.2:
1267
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1268
+ elif proportion > 1.0:
1269
+ assert not torch.allclose(x.grad, y_grad)
1270
+
1271
+
1272
+ def _test_whiten():
1273
+ for proportion in [0.1, 0.5, 10.0]:
1274
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1275
+ x = torch.randn(100, 128)
1276
+ direction = torch.randn(128)
1277
+ coeffs = torch.randn(100, 1)
1278
+ x += proportion * direction * coeffs
1279
+
1280
+ x.requires_grad = True
1281
+
1282
+ num_channels = 128
1283
+ m = Whiten(
1284
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1285
+ ) # grad_scale
1286
+
1287
+ for _ in range(4):
1288
+ y = m(x)
1289
+
1290
+ y_grad = torch.randn_like(x)
1291
+ y.backward(gradient=y_grad)
1292
+
1293
+ if proportion < 0.2:
1294
+ assert torch.allclose(x.grad, y_grad)
1295
+ elif proportion > 1.0:
1296
+ assert not torch.allclose(x.grad, y_grad)
1297
+
1298
+
1299
+ def _test_activation_balancer_sign():
1300
+ probs = torch.arange(0, 1, 0.01)
1301
+ N = 1000
1302
+ x = 1.0 * (
1303
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1304
+ )
1305
+ x = x.detach()
1306
+ x.requires_grad = True
1307
+ m = ActivationBalancer(
1308
+ probs.numel(),
1309
+ channel_dim=0,
1310
+ min_positive=0.05,
1311
+ max_positive=0.95,
1312
+ max_factor=0.2,
1313
+ min_abs=0.0,
1314
+ )
1315
+
1316
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1317
+
1318
+ y = m(x)
1319
+ y.backward(gradient=y_grad)
1320
+ print("_test_activation_balancer_sign: x = ", x)
1321
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1322
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1323
+
1324
+
1325
+ def _test_activation_balancer_magnitude():
1326
+ magnitudes = torch.arange(0, 1, 0.01)
1327
+ N = 1000
1328
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1329
+ -1
1330
+ )
1331
+ x = x.detach()
1332
+ x.requires_grad = True
1333
+ m = ActivationBalancer(
1334
+ magnitudes.numel(),
1335
+ channel_dim=0,
1336
+ min_positive=0.0,
1337
+ max_positive=1.0,
1338
+ max_factor=0.2,
1339
+ min_abs=0.2,
1340
+ max_abs=0.8,
1341
+ min_prob=1.0,
1342
+ )
1343
+
1344
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1345
+
1346
+ y = m(x)
1347
+ y.backward(gradient=y_grad)
1348
+ print("_test_activation_balancer_magnitude: x = ", x)
1349
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1350
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1351
+
1352
+
1353
+ def _test_basic_norm():
1354
+ num_channels = 128
1355
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1356
+
1357
+ x = torch.randn(500, num_channels)
1358
+
1359
+ y = m(x)
1360
+
1361
+ assert y.shape == x.shape
1362
+ x_rms = (x ** 2).mean().sqrt()
1363
+ y_rms = (y ** 2).mean().sqrt()
1364
+ print("x rms = ", x_rms)
1365
+ print("y rms = ", y_rms)
1366
+ assert y_rms < x_rms
1367
+ assert y_rms > 0.5 * x_rms
1368
+
1369
+
1370
+ def _test_double_swish_deriv():
1371
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1372
+ x.requires_grad = True
1373
+ m = DoubleSwish()
1374
+
1375
+ tol = (1.2 - (-0.043637)) / 255.0
1376
+ torch.autograd.gradcheck(m, x, atol=tol)
1377
+
1378
+ # for self-test.
1379
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1380
+ x.requires_grad = True
1381
+ y = m(x)
1382
+
1383
+
1384
+ def _test_softmax():
1385
+ a = torch.randn(2, 10, dtype=torch.float64)
1386
+ b = a.clone()
1387
+ a.requires_grad = True
1388
+ b.requires_grad = True
1389
+ a.softmax(dim=1)[:, 0].sum().backward()
1390
+ print("a grad = ", a.grad)
1391
+ softmax(b, dim=1)[:, 0].sum().backward()
1392
+ print("b grad = ", b.grad)
1393
+ assert torch.allclose(a.grad, b.grad)
1394
+
1395
+
1396
+ if __name__ == "__main__":
1397
+ logging.getLogger().setLevel(logging.INFO)
1398
+ torch.set_num_threads(1)
1399
+ torch.set_num_interop_threads(1)
1400
+ _test_softmax()
1401
+ _test_whiten()
1402
+ _test_max_eig()
1403
+ _test_activation_balancer_sign()
1404
+ _test_activation_balancer_magnitude()
1405
+ _test_basic_norm()
1406
+ _test_double_swish_deriv()
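A minimal usage sketch (editor's addition, not part of the uploaded files): DoubleSwish computes x * sigmoid(x - 1) with a memory-saving uint8-quantized derivative in the backward pass, and BalancedDoubleSwish prepends an ActivationBalancer so channel statistics stay well behaved during training. The import path follows the repo layout (models/modules/scaling.py); the shapes and hyperparameters below are illustrative only.

import torch
from models.modules.scaling import BalancedDoubleSwish

d_model = 256
act = BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25)

x = torch.randn(8, 100, d_model, requires_grad=True)  # (batch, time, channels)
y = act(x)             # forward is exactly x * sigmoid(x - 1); the balancer is identity in forward
y.sum().backward()     # backward uses the uint8-quantized derivative saved during forward
print(y.shape, x.grad.shape)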
models/modules/transformer.py ADDED
@@ -0,0 +1,1089 @@
 
1
+ import copy, logging
2
+ import numbers
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+
10
+ from .activation import MultiheadAttention
11
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
12
+ from .scaling import BasicNorm as _BasicNorm
13
+
14
+ _shape_t = Union[int, List[int], torch.Size]
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
19
+ normalized_shape: Tuple[int, ...]
20
+ eps: float
21
+ elementwise_affine: bool
22
+
23
+ def __init__(
24
+ self,
25
+ normalized_shape: _shape_t,
26
+ eps: float = 1e-5,
27
+ elementwise_affine: bool = True,
28
+ device=None,
29
+ dtype=None,
30
+ ) -> None:
31
+ factory_kwargs = {"device": device, "dtype": dtype}
32
+ super(LayerNorm, self).__init__()
33
+ if isinstance(normalized_shape, numbers.Integral):
34
+ # mypy error: incompatible types in assignment
35
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
36
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
37
+ self.eps = eps
38
+ self.elementwise_affine = elementwise_affine
39
+ if self.elementwise_affine:
40
+ self.weight = nn.Parameter(
41
+ torch.empty(self.normalized_shape, **factory_kwargs)
42
+ )
43
+ self.bias = nn.Parameter(
44
+ torch.empty(self.normalized_shape, **factory_kwargs)
45
+ )
46
+ else:
47
+ self.register_parameter("weight", None)
48
+ self.register_parameter("bias", None)
49
+
50
+ self.reset_parameters()
51
+
52
+ def reset_parameters(self) -> None:
53
+ if self.elementwise_affine:
54
+ nn.init.ones_(self.weight)
55
+ nn.init.zeros_(self.bias)
56
+
57
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
58
+ if isinstance(input, tuple):
59
+ input, embedding = input
60
+ return (
61
+ F.layer_norm(
62
+ input,
63
+ self.normalized_shape,
64
+ self.weight,
65
+ self.bias,
66
+ self.eps,
67
+ ),
68
+ embedding,
69
+ )
70
+
71
+ assert embedding is None
72
+ return F.layer_norm(
73
+ input, self.normalized_shape, self.weight, self.bias, self.eps
74
+ )
75
+
76
+ def extra_repr(self) -> str:
77
+ return (
78
+ "{normalized_shape}, eps={eps}, "
79
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
80
+ )
81
+
82
+
83
+ class AdaptiveLayerNorm(nn.Module):
84
+ r"""Adaptive Layer Normalization"""
85
+
86
+ def __init__(self, d_model, norm) -> None:
87
+ super(AdaptiveLayerNorm, self).__init__()
88
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
89
+ self.norm = norm
90
+ self.d_model = d_model
91
+ self.eps = self.norm.eps
92
+
93
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
94
+ if isinstance(input, tuple):
95
+ input, embedding = input
96
+ weight, bias = torch.split(
97
+ self.project_layer(embedding),
98
+ split_size_or_sections=self.d_model,
99
+ dim=-1,
100
+ )
101
+ return (weight * self.norm(input) + bias, embedding)
102
+
103
+ weight, bias = torch.split(
104
+ self.project_layer(embedding),
105
+ split_size_or_sections=self.d_model,
106
+ dim=-1,
107
+ )
108
+ return weight * self.norm(input) + bias
109
+
110
+
111
+ class BasicNorm(_BasicNorm):
112
+ def __init__(
113
+ self,
114
+ d_model: int,
115
+ eps: float = 1e-5,
116
+ device=None,
117
+ dtype=None,
118
+ ):
119
+ super(BasicNorm, self).__init__(d_model, eps=eps)
120
+
121
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
122
+ if isinstance(input, tuple):
123
+ input, embedding = input
124
+ return (
125
+ super(BasicNorm, self).forward(input),
126
+ embedding,
127
+ )
128
+
129
+ assert embedding is None
130
+ return super(BasicNorm, self).forward(input)
131
+
132
+
133
+ class BalancedBasicNorm(nn.Module):
134
+ def __init__(
135
+ self,
136
+ d_model: int,
137
+ eps: float = 1e-5,
138
+ device=None,
139
+ dtype=None,
140
+ ):
141
+ super(BalancedBasicNorm, self).__init__()
142
+ self.balancer = ActivationBalancer(
143
+ d_model,
144
+ channel_dim=-1,
145
+ min_positive=0.45,
146
+ max_positive=0.55,
147
+ max_abs=6.0,
148
+ )
149
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
150
+
151
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
152
+ if isinstance(input, tuple):
153
+ input, embedding = input
154
+ return self.norm((self.balancer(input), embedding))
155
+
156
+ assert embedding is None
157
+ return self.norm(self.balancer(input))
158
+
159
+
160
+ class IdentityNorm(nn.Module):
161
+ def __init__(
162
+ self,
163
+ d_model: int,
164
+ eps: float = 1e-5,
165
+ device=None,
166
+ dtype=None,
167
+ ) -> None:
168
+ super(IdentityNorm, self).__init__()
169
+
170
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
171
+ if isinstance(input, tuple):
172
+ return input
173
+
174
+ assert embedding is None
175
+ return input
176
+
177
+
178
+ class TransformerEncoderLayer(nn.Module):
179
+ __constants__ = ["batch_first", "norm_first"]
180
+
181
+ def __init__(
182
+ self,
183
+ d_model: int,
184
+ nhead: int,
185
+ dim_feedforward: int = 2048,
186
+ dropout: float = 0.1,
187
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
188
+ batch_first: bool = False,
189
+ norm_first: bool = False,
190
+ device=None,
191
+ dtype=None,
192
+ linear1_self_attention_cls: nn.Module = nn.Linear,
193
+ linear2_self_attention_cls: nn.Module = nn.Linear,
194
+ linear1_feedforward_cls: nn.Module = nn.Linear,
195
+ linear2_feedforward_cls: nn.Module = nn.Linear,
196
+ layer_norm_cls: nn.Module = LayerNorm,
197
+ layer_norm_eps: float = 1e-5,
198
+ adaptive_layer_norm=False,
199
+ ) -> None:
200
+ factory_kwargs = {"device": device, "dtype": dtype}
201
+ super(TransformerEncoderLayer, self).__init__()
202
+ self.self_attn = MultiheadAttention(
203
+ d_model,
204
+ nhead,
205
+ dropout=dropout,
206
+ batch_first=batch_first,
207
+ linear1_cls=linear1_self_attention_cls,
208
+ linear2_cls=linear2_self_attention_cls,
209
+ **factory_kwargs,
210
+ )
211
+
212
+ # Implementation of Feedforward model
213
+ self.linear1 = linear1_feedforward_cls(
214
+ d_model, dim_feedforward, **factory_kwargs
215
+ )
216
+ self.dropout = nn.Dropout(dropout)
217
+ self.linear2 = linear2_feedforward_cls(
218
+ dim_feedforward, d_model, **factory_kwargs
219
+ )
220
+
221
+ self.norm_first = norm_first
222
+ self.dropout1 = nn.Dropout(dropout)
223
+ self.dropout2 = nn.Dropout(dropout)
224
+
225
+ # Legacy string support for activation function.
226
+ if isinstance(activation, str):
227
+ activation = _get_activation_fn(activation)
228
+ elif isinstance(activation, partial):
229
+ activation = activation(d_model)
230
+ elif activation == BalancedDoubleSwish:
231
+ activation = BalancedDoubleSwish(d_model)
232
+
233
+ # # We can't test self.activation in forward() in TorchScript,
234
+ # # so stash some information about it instead.
235
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
236
+ # self.activation_relu_or_gelu = 1
237
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
238
+ # self.activation_relu_or_gelu = 2
239
+ # else:
240
+ # self.activation_relu_or_gelu = 0
241
+ self.activation = activation
242
+
243
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244
+ if layer_norm_cls == IdentityNorm:
245
+ norm2 = BalancedBasicNorm(
246
+ d_model, eps=layer_norm_eps, **factory_kwargs
247
+ )
248
+ else:
249
+ norm2 = layer_norm_cls(
250
+ d_model, eps=layer_norm_eps, **factory_kwargs
251
+ )
252
+
253
+ if adaptive_layer_norm:
254
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
255
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
256
+ else:
257
+ self.norm1 = norm1
258
+ self.norm2 = norm2
259
+
260
+ def __setstate__(self, state):
261
+ super(TransformerEncoderLayer, self).__setstate__(state)
262
+ if not hasattr(self, "activation"):
263
+ self.activation = F.relu
264
+
265
+ def forward(
266
+ self,
267
+ src,
268
+ src_mask: Optional[Tensor] = None,
269
+ src_key_padding_mask: Optional[Tensor] = None,
270
+ need_weights: Optional[bool] = False,
271
+ past: Optional[Tensor] = None,
272
+ ) -> Tensor:
273
+ r"""Pass the input through the encoder layer.
274
+
275
+ Args:
276
+ src: the sequence to the encoder layer (required).
277
+ src_mask: the mask for the src sequence (optional).
278
+ src_key_padding_mask: the mask for the src keys per batch (optional).
279
+
280
+ Shape:
281
+ see the docs in Transformer class.
282
+ """
283
+ if isinstance(src, dict):
284
+ sinu = src["sinu"]
285
+ pm_sinu = src["pm_sinu"]
286
+ src = src["input"]
287
+ else:
288
+ sinu = None
289
+ pm_sinu = None
290
+ x, stage_embedding = src, None
291
+ is_src_tuple = False
292
+ if isinstance(src, tuple):
293
+ x, stage_embedding = src
294
+ is_src_tuple = True
295
+
296
+ if src_key_padding_mask is not None:
297
+ _skpm_dtype = src_key_padding_mask.dtype
298
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
299
+ src_key_padding_mask
300
+ ):
301
+ raise AssertionError(
302
+ "only bool and floating types of key_padding_mask are supported"
303
+ )
304
+ if need_weights:
305
+ raise NotImplementedError
306
+ if self.norm_first:
307
+ out, attn = self._sa_block_attn(
308
+ self.norm1(x, stage_embedding),
309
+ src_mask,
310
+ src_key_padding_mask,
311
+ past, sinu = sinu
312
+ )
313
+ out, present = out # present is the kvcache of the present timestep
314
+ x = x + out
315
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
316
+ else:
317
+ out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past, sinu = sinu)
318
+ out, present = out # present is the kvcache of the present timestep
319
+ x = self.norm1(
320
+ x + out,
321
+ stage_embedding,
322
+ )
323
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
324
+ assert not is_src_tuple
325
+ # return (x, stage_embedding)
326
+ return (x, attn)
327
+ else:
328
+ if self.norm_first:
329
+ out = self._sa_block(
330
+ self.norm1(x, stage_embedding),
331
+ src_mask,
332
+ src_key_padding_mask, past, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q']
333
+ )
334
+ out, present = out # present is the kvcache of the present timestep
335
+ x = x + out
336
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
337
+ else:
338
+ out = self._sa_block(x, src_mask, src_key_padding_mask, past, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'])
339
+ out, present = out # present is the kvcache of the present timestep
340
+ x = self.norm1(
341
+ x + out,
342
+ stage_embedding,
343
+ )
344
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
345
+
346
+ if is_src_tuple:
347
+ x = (x, stage_embedding)
348
+ if present != None:
349
+ x = [x, present]
350
+ return x
351
+
352
+ # self-attention block
353
+ def _sa_block(
354
+ self,
355
+ x: Tensor,
356
+ attn_mask: Optional[Tensor],
357
+ key_padding_mask: Optional[Tensor],
358
+ past: Optional[Tensor] = None,
359
+ sinu = None,
360
+ q_sinu = None,
361
+ k_sinu = None
362
+ ) -> Tensor:
363
+ x = self.self_attn(
364
+ x,
365
+ x,
366
+ x,
367
+ attn_mask=attn_mask,
368
+ key_padding_mask=key_padding_mask,
369
+ need_weights=False,
370
+ past=past,
371
+ sinu = sinu,
372
+ q_sinu = q_sinu,
373
+ k_sinu = k_sinu
374
+ )
375
+ x, present = x
376
+ return self.dropout1(x), present
377
+
378
+ # self-attention block, also return attention weights
379
+ def _sa_block_attn(
380
+ self,
381
+ x: Tensor,
382
+ attn_mask: Optional[Tensor],
383
+ key_padding_mask: Optional[Tensor],
384
+ past: Optional[Tensor] = None,
385
+ ) -> Tensor:
386
+ x, attn = self.self_attn(
387
+ x,
388
+ x,
389
+ x,
390
+ attn_mask=attn_mask,
391
+ key_padding_mask=key_padding_mask,
392
+ need_weights=True,
393
+ past=past
394
+ )
395
+ x, present = x
396
+ return (self.dropout1(x), present), attn
397
+
398
+ # feed forward block
399
+ def _ff_block(self, x: Tensor) -> Tensor:
400
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
401
+ return self.dropout2(x)
402
+
403
+ def pre_compute_sinusoidal(dim, base, max_len = 10000): # a max length of 4000 mimi codes corresponds to 320 s, since mimi runs at 12.5 Hz
404
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
405
+ position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
406
+ inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
407
+ freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
408
+ freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
409
+ return {"sin": freqs.sin(), "cos": freqs.cos()}
410
+
411
+ def pre_compute_freqs(dim, base, max_len = 10000): # a max length of 4000 mimi codes corresponds to 320 s, since mimi runs at 12.5 Hz
412
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
413
+ position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
414
+ inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
415
+ freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
416
+ freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
417
+ return freqs
418
+
419
+ class TransformerEncoder(nn.Module):
420
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
421
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
422
+
423
+ Args:
424
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
425
+ num_layers: the number of sub-encoder-layers in the encoder (required).
426
+ norm: the layer normalization component (optional).
427
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
428
+ (and convert back on output). This will improve the overall performance of
429
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
430
+
431
+ Examples::
432
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
433
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
434
+ >>> src = torch.rand(10, 32, 512)
435
+ >>> out = transformer_encoder(src)
436
+ """
437
+ __constants__ = ["norm"]
438
+
439
+ def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model=None, nhead=None, args=None):
440
+ super(TransformerEncoder, self).__init__()
441
+ self.layers = _get_clones(encoder_layer, num_layers)
442
+ self.num_layers = num_layers
443
+ self.norm = norm
444
+ if args != None:
445
+ self.progress_no_multiple = args.progress_no_multiple
446
+ self.progress_scale = args.progress_scale
447
+ else:
448
+ self.progress_no_multiple = False
449
+ self.progress_scale = 1
450
+
451
+ if rope_base is not None:
452
+ if self.progress_no_multiple:
453
+ self.pm_freqs = pre_compute_freqs(d_model//nhead, rope_base)
454
+ self.sinu = None
455
+ else:
456
+ self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base)
457
+ self.pm_freqs = None
458
+ # logging.info(f"get precomputed sinusoidal for {rope_base=}: {self.sinu=}")
459
+ else:
460
+ self.sinu = None
461
+ self.pm_freqs = None
462
+
463
+ def forward(
464
+ self,
465
+ src: Tensor,
466
+ mask: Optional[Tensor] = None,
467
+ src_key_padding_mask: Optional[Tensor] = None,
468
+ return_layer_states: bool = False,
469
+ need_weights:Optional[bool] = False,
470
+ past: Optional[Tensor] = None,
471
+ ) -> Tensor:
472
+ r"""Pass the input through the encoder layers in turn.
473
+
474
+ Args:
475
+ src: the sequence to the encoder (required).
476
+ mask: the mask for the src sequence (optional).
477
+ src_key_padding_mask: the mask for the src keys per batch (optional).
478
+ return_layer_states: return layers' state (optional).
479
+
480
+ Shape:
481
+ see the docs in Transformer class.
482
+ """
483
+ if return_layer_states:
484
+ raise NotImplementedError
485
+ assert not need_weights
486
+ layer_states = [] # layers' output
487
+ output = src
488
+ for mod in self.layers:
489
+ output = mod(
490
+ output,
491
+ src_mask=mask,
492
+ src_key_padding_mask=src_key_padding_mask,
493
+ past=past
494
+ )
495
+ layer_states.append(output[0])
496
+
497
+ if self.norm is not None:
498
+ output = self.norm(output)
499
+
500
+ return layer_states, output
501
+ if need_weights:
502
+ raise NotImplementedError
503
+ assert not return_layer_states
504
+ layer_attn = [] # layers' output
505
+ output = src
506
+ for mod in self.layers:
507
+ output = mod(
508
+ output,
509
+ src_mask=mask,
510
+ src_key_padding_mask=src_key_padding_mask,
511
+ need_weights=True,
512
+ past=past
513
+ )
514
+ layer_attn.append(output[1])
515
+
516
+ if self.norm is not None:
517
+ output = self.norm(output)
518
+
519
+ return layer_attn, output
520
+
521
+ output = src
522
+ all_present = []
523
+ if self.sinu is not None:
524
+ # use rope
525
+ assert self.pm_freqs is None
526
+ for k, v in self.sinu.items():
527
+ self.sinu[k] = v.to(output.device)
528
+ if self.pm_freqs is not None:
529
+ assert self.sinu is None
530
+ self.pm_freqs = self.pm_freqs.to(output.device)
531
+ if src_key_padding_mask != None:
532
+ query_lens = (~src_key_padding_mask).int().sum(-1).to(output.device)
533
+ else:
534
+ query_lens = torch.tensor([output.shape[1]]*output.shape[0]).to(output.device)
535
+ assert query_lens.ndim==1, query_lens
536
+ q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
537
+ query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
538
+ q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
539
+ q_emb = q_emb / q_lens_expanded * self.progress_scale
540
+ q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
541
+ q_sin = q_emb.sin().unsqueeze(1)
542
+ self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}}
543
+ else:
544
+ self.pm_sinu = {"q": None}
545
+
546
+ output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
547
+ for n_layer, mod in enumerate(self.layers):
548
+ output = mod(
549
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
550
+ )
551
+ if isinstance(output, list):
552
+ output, present = output
553
+ all_present.append(present)
554
+ if self.sinu is not None or self.pm_sinu is not None:
555
+ output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
556
+ if self.sinu is not None or self.pm_sinu is not None:
557
+ output = output["input"]
558
+ if self.norm is not None:
559
+ output = self.norm(output)
560
+ if all_present != []:
561
+ all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
562
+ output = [output, all_present]
563
+ return output
564
+
565
+
566
+ class TransformerDecoderLayer(nn.Module):
567
+ __constants__ = ["batch_first", "norm_first"]
568
+
569
+ def __init__(
570
+ self,
571
+ d_model: int,
572
+ nhead: int,
573
+ dim_feedforward: int = 2048,
574
+ dropout: float = 0.1,
575
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
576
+ linear1_self_attention_cls: nn.Module = nn.Linear,
577
+ linear2_self_attention_cls: nn.Module = nn.Linear,
578
+ linear1_feedforward_cls: nn.Module = nn.Linear,
579
+ linear2_feedforward_cls: nn.Module = nn.Linear,
580
+ batch_first: bool = False,
581
+ norm_first: bool = False,
582
+ device=None,
583
+ dtype=None,
584
+ layer_norm_cls: nn.Module = LayerNorm,
585
+ layer_norm_eps: float = 1e-5,
586
+ adaptive_layer_norm=False,
587
+ ) -> None:
588
+ factory_kwargs = {"device": device, "dtype": dtype}
589
+ super(TransformerDecoderLayer, self).__init__()
590
+ self.self_attn = MultiheadAttention(
591
+ d_model,
592
+ nhead,
593
+ dropout=dropout,
594
+ batch_first=batch_first,
595
+ linear1_cls=linear1_self_attention_cls,
596
+ linear2_cls=linear2_self_attention_cls,
597
+ **factory_kwargs,
598
+ )
599
+ self.multihead_attn = MultiheadAttention(
600
+ d_model,
601
+ nhead,
602
+ dropout=dropout,
603
+ batch_first=batch_first,
604
+ linear1_cls=linear1_self_attention_cls,
605
+ linear2_cls=linear2_self_attention_cls,
606
+ **factory_kwargs,
607
+ )
608
+ # Implementation of Feedforward model
609
+ self.linear1 = linear1_feedforward_cls(
610
+ d_model, dim_feedforward, **factory_kwargs
611
+ )
612
+ self.dropout = nn.Dropout(dropout)
613
+ self.linear2 = linear2_feedforward_cls(
614
+ dim_feedforward, d_model, **factory_kwargs
615
+ )
616
+
617
+ self.norm_first = norm_first
618
+ self.dropout1 = nn.Dropout(dropout)
619
+ self.dropout2 = nn.Dropout(dropout)
620
+ self.dropout3 = nn.Dropout(dropout)
621
+
622
+ # Legacy string support for activation function.
623
+ if isinstance(activation, str):
624
+ self.activation = _get_activation_fn(activation)
625
+ elif isinstance(activation, partial):
626
+ self.activation = activation(d_model)
627
+ elif activation == BalancedDoubleSwish:
628
+ self.activation = BalancedDoubleSwish(d_model)
629
+ else:
630
+ self.activation = activation
631
+
632
+ if adaptive_layer_norm:
633
+ norm1 = layer_norm_cls(
634
+ d_model, eps=layer_norm_eps, **factory_kwargs
635
+ )
636
+ norm2 = layer_norm_cls(
637
+ d_model, eps=layer_norm_eps, **factory_kwargs
638
+ )
639
+ norm3 = layer_norm_cls(
640
+ d_model, eps=layer_norm_eps, **factory_kwargs
641
+ )
642
+
643
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
644
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
645
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
646
+ else:
647
+ self.norm1 = layer_norm_cls(
648
+ d_model, eps=layer_norm_eps, **factory_kwargs
649
+ )
650
+ self.norm2 = layer_norm_cls(
651
+ d_model, eps=layer_norm_eps, **factory_kwargs
652
+ )
653
+ if layer_norm_cls == IdentityNorm:
654
+ self.norm3 = BalancedBasicNorm(
655
+ d_model, eps=layer_norm_eps, **factory_kwargs
656
+ )
657
+ else:
658
+ self.norm3 = layer_norm_cls(
659
+ d_model, eps=layer_norm_eps, **factory_kwargs
660
+ )
661
+
662
+ def forward(
663
+ self,
664
+ tgt: Tensor,
665
+ memory: Tensor,
666
+ tgt_mask: Optional[Tensor] = None,
667
+ memory_mask: Optional[Tensor] = None,
668
+ tgt_key_padding_mask: Optional[Tensor] = None,
669
+ memory_key_padding_mask: Optional[Tensor] = None,
670
+ tgt_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
671
+ memory_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
672
+ past: Optional[Tensor] = None,
673
+ ) -> Tensor:
674
+ r"""Pass the inputs (and mask) through the decoder layer.
675
+
676
+ Args:
677
+ tgt: the sequence to the decoder layer (required).
678
+ memory: the sequence from the last layer of the encoder (required).
679
+ tgt_mask: the mask for the tgt sequence (optional).
680
+ memory_mask: the mask for the memory sequence (optional).
681
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
682
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
683
+ past: the previous kvcache of the decoder (optional). shape: (2, batch_size, num_heads, seq_len, head_dim)
684
+
685
+ Shape:
686
+ see the docs in Transformer class.
687
+ """
688
+ if isinstance(tgt, dict):
689
+ pm_sinu = tgt["pm_sinu"]
690
+ sinu = tgt["sinu"]
691
+ args = tgt["args"]
692
+ tgt = tgt["input"]
693
+ else:
694
+ pm_sinu = None
695
+ sinu = None
696
+ args = None
697
+ tgt_is_tuple = False
698
+ if isinstance(tgt, tuple):
699
+ x, stage_embedding = tgt
700
+ tgt_is_tuple = True
701
+ else:
702
+ x, stage_embedding = tgt, None
703
+ # logging.info(f"{tgt_key_padding_mask=}, {memory_key_padding_mask=}")
704
+ # logging.info(f"{tgt_key_padding_mask.shape=}, {memory_key_padding_mask.shape=}")
705
+ # logging.info(f"{query_lens=}, {key_lens=}")
706
+
707
+ # past stores the kvcache for self-attention, and it can also be used to infer q_offset
708
+ if past is not None and past.ndim > 2:
709
+ q_offset = past[0].shape[-2] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q
710
+ else:
711
+ q_offset = 0
712
+
713
+
714
+ if self.norm_first:
715
+ temp = self._sa_block(
716
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset
717
+ )
718
+ present = temp[1]
719
+ x = x + temp[0]
720
+ cross_out = self._mha_block(
721
+ self.norm2(x, stage_embedding),
722
+ memory,
723
+ memory_mask,
724
+ memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args = args, q_offset=q_offset
725
+ )
726
+ if isinstance(cross_out, dict):
727
+ attention_weights = cross_out["attention_weights"]
728
+ cross_out = cross_out["x"]
729
+ else:
730
+ attention_weights = None
731
+ x = x + cross_out
732
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
733
+ else:
734
+ temp = self._sa_block(x, tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset)
735
+ present = temp[1]
736
+ x = self.norm1(
737
+ x + temp[0],
738
+ stage_embedding,
739
+ )
740
+ cross_out = self._mha_block(
741
+ x, memory, memory_mask, memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args, q_offset=q_offset
742
+ )
743
+ if isinstance(cross_out, dict):
744
+ attention_weights = cross_out["attention_weights"]
745
+ cross_out = cross_out["x"]
746
+ else:
747
+ attention_weights = None
748
+ x = self.norm2(
749
+ x
750
+ + cross_out,
751
+ stage_embedding,
752
+ )
753
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
754
+
755
+ if attention_weights is not None:
756
+ x = {"x": x, "attention_weights": attention_weights}
757
+ if tgt_is_tuple:
758
+ x = (x, stage_embedding)
759
+ if present != None:
760
+ x = [x, present]
761
+ return x
762
+
763
+ # self-attention block
764
+ def _sa_block(
765
+ self,
766
+ x: Tensor,
767
+ attn_mask: Optional[Tensor],
768
+ key_padding_mask: Optional[Tensor],
769
+ q_sinu=None,
770
+ k_sinu=None,
771
+ sinu = None,
772
+ args = None,
773
+ past = None,
774
+ q_offset = 0
775
+ ) -> Tensor:
776
+ # if past is not None and past.ndim > 2:
777
+ # print(f"self-attn, k len: {past[0].shape[-2] + x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
778
+ # else:
779
+ # print(f"self-attn, k len: {x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
780
+ x = self.self_attn(
781
+ x,
782
+ x,
783
+ x,
784
+ attn_mask=attn_mask,
785
+ key_padding_mask=key_padding_mask,
786
+ need_weights=False,
787
+ q_sinu = q_sinu,
788
+ k_sinu = k_sinu,
789
+ sinu = sinu,
790
+ past = past,
791
+ q_offset = q_offset
792
+ )
793
+ x, present = x
794
+ return self.dropout1(x), present
795
+
796
+ # multihead attention block
797
+ def _mha_block(
798
+ self,
799
+ x: Tensor,
800
+ mem: Tensor,
801
+ attn_mask: Optional[Tensor],
802
+ key_padding_mask: Optional[Tensor],
803
+ q_sinu = None,
804
+ k_sinu = None,
805
+ sinu = None,
806
+ args = None,
807
+ q_offset = 0
808
+ ) -> Tensor:
809
+ # print(f"cross-attn, k len: {mem.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
810
+ x = self.multihead_attn(
811
+ x,
812
+ mem,
813
+ mem,
814
+ attn_mask=attn_mask,
815
+ key_padding_mask=key_padding_mask,
816
+ need_weights=False,
817
+ q_sinu = q_sinu,
818
+ k_sinu = k_sinu,
819
+ sinu = sinu,
820
+ args = args,
821
+ q_offset = q_offset
822
+ )
823
+ if len(x) == 2 and isinstance(x[0], dict) and "attention_weights" in x[0]:
824
+ x, present = x
825
+ attention_weights = x['attention_weights']
826
+ x = x['attn_output']
827
+ return {"x": self.dropout2(x), "attention_weights": attention_weights}
828
+ elif len(x) == 2:
829
+ x = x[0]
830
+ return self.dropout2(x)
831
+
832
+ # feed forward block
833
+ def _ff_block(self, x: Tensor) -> Tensor:
834
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
835
+ return self.dropout3(x)
836
+
837
+
838
+ def _get_clones(module, N):
839
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
840
+
841
+
842
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
843
+ if activation == "relu":
844
+ return F.relu
845
+ elif activation == "gelu":
846
+ return F.gelu
847
+
848
+ raise RuntimeError(
849
+ "activation should be relu/gelu, not {}".format(activation)
850
+ )
851
+ def _generate_square_subsequent_mask(
852
+ sz: int,
853
+ device: Optional[torch.device] = None,
854
+ dtype: Optional[torch.dtype] = None,
855
+ ) -> Tensor:
856
+ r"""Generate a square causal mask for the sequence.
857
+
858
+ The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
859
+ """
860
+ if device is None:
861
+ device = torch.device('cpu')
862
+ if dtype is None:
863
+ dtype = torch.float32
864
+ return torch.triu(
865
+ torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
866
+ diagonal=1,
867
+ )
868
+ def _get_seq_len(
869
+ src: Tensor,
870
+ batch_first: bool
871
+ ) -> Optional[int]:
872
+
873
+ if src.is_nested:
874
+ return None
875
+ else:
876
+ src_size = src.size()
877
+ if len(src_size) == 2:
878
+ # unbatched: S, E
879
+ return src_size[0]
880
+ else:
881
+ # batched: B, S, E if batch_first else S, B, E
882
+ seq_len_pos = 1 if batch_first else 0
883
+ return src_size[seq_len_pos]
884
+
885
+ def _detect_is_causal_mask(
886
+ mask: Optional[Tensor],
887
+ is_causal: Optional[bool] = None,
888
+ size: Optional[int] = None,
889
+ ) -> bool:
890
+ """Return whether the given attention mask is causal.
891
+
892
+ Warning:
893
+ If ``is_causal`` is not ``None``, its value will be returned as is. If a
894
+ user supplies an incorrect ``is_causal`` hint,
895
+
896
+ ``is_causal=False`` when the mask is in fact a causal attention mask
897
+ may lead to reduced performance relative to what would be achievable
898
+ with ``is_causal=True``;
899
+ ``is_causal=True`` when the mask is in fact not a causal attention mask
900
+ may lead to incorrect and unpredictable execution - in some scenarios,
901
+ a causal mask may be applied based on the hint, in other execution
902
+ scenarios the specified mask may be used. The choice may not appear
903
+ to be deterministic, in that a number of factors like alignment,
904
+ hardware SKU, etc influence the decision whether to use a mask or
905
+ rely on the hint.
906
+ ``size`` if not None, check whether the mask is a causal mask of the provided size.
907
+ Otherwise, checks for any causal mask.
908
+ """
909
+ # Prevent type refinement
910
+ make_causal = (is_causal is True)
911
+
912
+ if is_causal is None and mask is not None:
913
+ sz = size if size is not None else mask.size(-2)
914
+ causal_comparison = _generate_square_subsequent_mask(
915
+ sz, device=mask.device, dtype=mask.dtype)
916
+
917
+ # Do not use `torch.equal` so we handle batched masks by
918
+ # broadcasting the comparison.
919
+ if mask.size() == causal_comparison.size():
920
+ make_causal = bool((mask == causal_comparison).all())
921
+ else:
922
+ make_causal = False
923
+
924
+ return make_causal
925
+
926
+ class TransformerDecoder(nn.Module):
927
+ r"""TransformerDecoder is a stack of N decoder layers.
928
+
929
+ Args:
930
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
931
+ num_layers: the number of sub-decoder-layers in the decoder (required).
932
+ norm: the layer normalization component (optional).
933
+
934
+ Examples::
935
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
936
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
937
+ >>> memory = torch.rand(10, 32, 512)
938
+ >>> tgt = torch.rand(20, 32, 512)
939
+ >>> out = transformer_decoder(tgt, memory)
940
+ """
941
+
942
+ __constants__ = ['norm']
943
+
944
+ def __init__(
945
+ self,
946
+ decoder_layer: "TransformerDecoderLayer",
947
+ num_layers: int,
948
+ norm: Optional[nn.Module] = None,
949
+ rope_base=None,
950
+ d_model=None,
951
+ nhead=None,
952
+ args=None
953
+ ) -> None:
954
+ super().__init__()
955
+ torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
956
+ self.layers = _get_clones(decoder_layer, num_layers)
957
+ self.num_layers = num_layers
958
+ self.norm = norm
959
+ self.args = args
960
+ if getattr(self.args, 'decoder_regular_rope', False):
961
+ self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base)
962
+ self.pm_freqs = None
963
+ else:
964
+ self.sinu = None
965
+ if rope_base is not None:
966
+ self.pm_freqs = pre_compute_freqs(d_model/nhead, rope_base)
967
+ # logging.info(f"get precomputed freqs for {rope_base=}: {self.freqs=}")
968
+ else:
969
+ self.pm_freqs = None
970
+ self.progress_scale = getattr(self.args, "progress_scale", 1.0)
971
+
972
+ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
973
+ memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
974
+ memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
975
+ memory_is_causal: bool = False, query_lens: Optional[Tensor] = None, key_lens: Optional[Tensor] = None, past: Optional[Tensor] = None) -> Tensor:
976
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
977
+
978
+ Args:
979
+ tgt: the sequence to the decoder (required).
980
+ memory: the sequence from the last layer of the encoder (required).
981
+ tgt_mask: the mask for the tgt sequence (optional).
982
+ memory_mask: the mask for the memory sequence (optional).
983
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
984
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
985
+ tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
986
+ Default: ``None``; try to detect a causal mask.
987
+ Warning:
988
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
989
+ the causal mask. Providing incorrect hints can result in
990
+ incorrect execution, including forward and backward
991
+ compatibility.
992
+ memory_is_causal: If specified, applies a causal mask as
993
+ ``memory mask``.
994
+ Default: ``False``.
995
+ Warning:
996
+ ``memory_is_causal`` provides a hint that
997
+ ``memory_mask`` is the causal mask. Providing incorrect
998
+ hints can result in incorrect execution, including
999
+ forward and backward compatibility.
1000
+
1001
+ Shape:
1002
+ see the docs in :class:`~torch.nn.Transformer`.
1003
+ """
1004
+ output = tgt
1005
+
1006
+ # seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
1007
+ # tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
1008
+ if self.sinu is not None:
1009
+ assert self.pm_freqs is None
1010
+ for key in self.sinu:
1011
+ self.sinu[key] = self.sinu[key].to(output.device)
1012
+ if self.pm_freqs is not None:
1013
+ assert self.sinu is None
1014
+ if not self.training and hasattr(self, "pm_sinu") and past is not None and past[0].ndim > 2: # inference mode, will use cached sinu for the same example
1015
+ assert self.pm_sinu['q'] is not None and self.pm_sinu['k'] is not None
1016
+ # check batch size, need to modify the batch size if we use multi_trial during inference
1017
+ if self.pm_sinu['q']['cos'].shape[0] != tgt.shape[0]:
1018
+ if self.pm_sinu['q']['cos'].shape[0] > tgt.shape[0]:
1019
+ self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'][:tgt.shape[0]]
1020
+ self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'][:tgt.shape[0]]
1021
+ self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'][:tgt.shape[0]]
1022
+ self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'][:tgt.shape[0]]
1023
+ else:
1024
+ assert self.pm_sinu['q']['cos'].shape[0] == 1
1025
+ self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'].repeat(tgt.shape[0], 1, 1, 1)
1026
+ self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'].repeat(tgt.shape[0], 1, 1, 1)
1027
+ self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'].repeat(tgt.shape[0], 1, 1, 1)
1028
+ self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'].repeat(tgt.shape[0], 1, 1, 1)
1029
+ pass
1030
+ else:
1031
+ self.pm_freqs = self.pm_freqs.to(output.device)
1032
+ if query_lens is None:
1033
+ query_lens = (~tgt_key_padding_mask).int().sum(-1).to(tgt.device)
1034
+ if key_lens is None:
1035
+ key_lens = (~memory_key_padding_mask).int().sum(-1).to(tgt.device)
1036
+ assert key_lens.ndim==1, key_lens
1037
+ assert query_lens.ndim==1, query_lens
1038
+ q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
1039
+ k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
1040
+ query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
1041
+ key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1)
1042
+ q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
1043
+ k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d]
1044
+ q_emb = q_emb / q_lens_expanded * self.progress_scale
1045
+ k_emb = k_emb / k_lens_expanded * self.progress_scale
1046
+ q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
1047
+ q_sin = q_emb.sin().unsqueeze(1)
1048
+ k_cos = k_emb.cos().unsqueeze(1)
1049
+ k_sin = k_emb.sin().unsqueeze(1)
1050
+ self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}, "k": {"cos": k_cos, "sin": k_sin}}
1051
+ else:
1052
+ self.pm_sinu = {"q": None, "k": None}
1053
+
1054
+ output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
1055
+ if past != None:
1056
+ all_present = []
1057
+ if self.training and getattr(self.args, "attention_alignment_loss", 0):
1058
+ all_attn_weights = []
1059
+ for i, mod in enumerate(self.layers):
1060
+ output = mod(output, memory, tgt_mask=tgt_mask,
1061
+ memory_mask=memory_mask,
1062
+ tgt_key_padding_mask=tgt_key_padding_mask,
1063
+ memory_key_padding_mask=memory_key_padding_mask,
1064
+ past=past[i] if past != None else None
1065
+ # tgt_is_causal=tgt_is_causal,
1066
+ # memory_is_causal=memory_is_causal
1067
+ )
1068
+ if past != None:
1069
+ output, cur_present = output
1070
+ all_present.append(cur_present)
1071
+ if isinstance(output, dict):
1072
+ current_attn_weights = output["attention_weights"]
1073
+ all_attn_weights.append(current_attn_weights)
1074
+ output = output["x"]
1075
+ if self.sinu is not None or self.pm_sinu is not None:
1076
+ output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
1077
+ if self.pm_sinu is not None or self.sinu is not None:
1078
+ output = output["input"]
1079
+ if self.norm is not None:
1080
+ output = self.norm(output)
1081
+ if self.training and getattr(self.args, "attention_alignment_loss", 0):
1082
+ assert len(all_attn_weights) == self.num_layers, f"{len(all_attn_weights)=}, {self.num_layers=}"
1083
+ output = {"output": output, "attention_weights": all_attn_weights}
1084
+ if past != None:
1085
+ all_present = torch.stack(all_present, dim=0)
1086
+ output = [output, all_present]
1087
+ else:
1088
+ output = [output, None]
1089
+ return output
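An illustrative sketch (editor's addition, not part of the uploaded files) of what the "progress" rotary-embedding helpers above precompute. pre_compute_freqs and pre_compute_sinusoidal are called as module-level functions by TransformerEncoder/TransformerDecoder, so the import below assumes they are importable from models/modules/transformer.py; head_dim, rope_base, and the lengths are made-up values.

import torch
from models.modules.transformer import pre_compute_freqs

head_dim, rope_base = 64, 10000
freqs = pre_compute_freqs(head_dim, rope_base, max_len=500)   # [1, 500, head_dim]

# Mirror what the encoder/decoder do: rescale absolute positions by sequence length,
# so position p of a length-L sequence gets an angle proportional to p / (L - 1),
# i.e. its relative progress through the sequence (progress_scale omitted, default 1).
lengths = torch.tensor([120, 480]).view(-1, 1, 1).float()     # [B, 1, 1]
emb = freqs * (lengths / (lengths - 1)) / lengths
cos, sin = emb.cos().unsqueeze(1), emb.sin().unsqueeze(1)     # [B, 1, 500, head_dim]; dim 1 broadcasts over heads
print(cos.shape, sin.shape)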
models/modules/utils.py ADDED
@@ -0,0 +1,36 @@
 
1
+ import torch
2
+
3
+
4
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
5
+ """
6
+ Args:
7
+ lengths:
8
+ A 1-D tensor containing sentence lengths.
9
+ max_len:
10
+ The length of masks.
11
+ Returns:
12
+ Return a 2-D bool tensor, where masked positions
13
+ are filled with `True` and non-masked positions are
14
+ filled with `False`.
15
+
16
+ >>> lengths = torch.tensor([1, 3, 2, 5])
17
+ >>> make_pad_mask(lengths)
18
+ tensor([[False, True, True, True, True],
19
+ [False, False, False, True, True],
20
+ [False, False, True, True, True],
21
+ [False, False, False, False, False]])
22
+ """
23
+ assert lengths.ndim == 1, lengths.ndim
24
+ max_len = max(max_len, lengths.max())
25
+ n = lengths.size(0)
26
+ seq_range = torch.arange(0, max_len, device=lengths.device)
27
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
28
+
29
+ return expaned_lengths >= lengths.unsqueeze(-1)
30
+
31
+ def generate_partial_autoregressive_mask(sz, start, end):
32
+ mask = torch.zeros(sz, sz).bool()
33
+ mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
34
+ mask[:start, start:end] = True
35
+ mask[end:, start:end] = True
36
+ return mask
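A quick illustration (editor's addition, not part of the uploaded files) of the two helpers above with toy sizes; the import path follows the repo layout.

import torch
from models.modules.utils import make_pad_mask, generate_partial_autoregressive_mask

print(make_pad_mask(torch.tensor([2, 4])))        # [2, 4] bool, True marks padding positions
mask = generate_partial_autoregressive_mask(6, 2, 5)
# Inside the span [2, 5): causal attention among the span's own positions; queries
# outside the span are blocked from attending into it; columns outside the span stay
# fully visible to every query. True means "masked out", as in nn.MultiheadAttention.
print(mask.int())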
models/modules/visualizer.py ADDED
@@ -0,0 +1,107 @@
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/models/visualizer.py
2
+ #!/usr/bin/env python3
3
+ # Copyright 2023 (authors: Feiteng Li)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ from typing import Dict, List, Tuple, Union
21
+
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ import torch
25
+
26
+
27
+ def visualize(
28
+ predicts: Tuple[torch.Tensor],
29
+ batch: Dict[str, Union[List, torch.Tensor]],
30
+ output_dir: str,
31
+ limit: int = 4,
32
+ ) -> None:
33
+ text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
34
+ text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
35
+ audio_features = batch["audio_features"].to("cpu").detach().numpy()
36
+ audio_features_lens = (
37
+ batch["audio_features_lens"].to("cpu").detach().numpy()
38
+ )
39
+ assert text_tokens.ndim == 2
40
+
41
+ utt_ids, texts = batch["utt_id"], batch["text"]
42
+
43
+ encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
44
+ decoder_outputs = predicts[1]
45
+ if isinstance(decoder_outputs, list):
46
+ decoder_outputs = decoder_outputs[-1]
47
+ decoder_outputs = (
48
+ decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
49
+ )
50
+
51
+ vmin, vmax = 0, 1024 # Encodec
52
+ if decoder_outputs.dtype == np.float32:
53
+ vmin, vmax = -6, 0 # Fbank
54
+
55
+ num_figures = 3
56
+ for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
57
+ _ = plt.figure(figsize=(14, 8 * num_figures))
58
+
59
+ S = text_tokens_lens[b]
60
+ T = audio_features_lens[b]
61
+
62
+ # encoder
63
+ plt.subplot(num_figures, 1, 1)
64
+ plt.title(f"Text: {text}")
65
+ plt.imshow(
66
+ X=np.transpose(encoder_outputs[b]),
67
+ cmap=plt.get_cmap("jet"),
68
+ aspect="auto",
69
+ interpolation="nearest",
70
+ )
71
+ plt.gca().invert_yaxis()
72
+ plt.axvline(x=S - 0.4, linewidth=2, color="r")
73
+ plt.xlabel("Encoder Output")
74
+ plt.colorbar()
75
+
76
+ # decoder
77
+ plt.subplot(num_figures, 1, 2)
78
+ plt.imshow(
79
+ X=np.transpose(decoder_outputs[b]),
80
+ cmap=plt.get_cmap("jet"),
81
+ aspect="auto",
82
+ interpolation="nearest",
83
+ vmin=vmin,
84
+ vmax=vmax,
85
+ )
86
+ plt.gca().invert_yaxis()
87
+ plt.axvline(x=T - 0.4, linewidth=2, color="r")
88
+ plt.xlabel("Decoder Output")
89
+ plt.colorbar()
90
+
91
+ # target
92
+ plt.subplot(num_figures, 1, 3)
93
+ plt.imshow(
94
+ X=np.transpose(audio_features[b]),
95
+ cmap=plt.get_cmap("jet"),
96
+ aspect="auto",
97
+ interpolation="nearest",
98
+ vmin=vmin,
99
+ vmax=vmax,
100
+ )
101
+ plt.gca().invert_yaxis()
102
+ plt.axvline(x=T - 0.4, linewidth=2, color="r")
103
+ plt.xlabel("Decoder Target")
104
+ plt.colorbar()
105
+
106
+ plt.savefig(f"{output_dir}/{utt_id}.png")
107
+ plt.close()
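A toy invocation sketch (editor's addition, not part of the uploaded files). The batch keys and the predicts tuple mirror what visualize() reads; the shapes and feature dimensions are invented. One PNG per utterance is written to output_dir, stacking encoder output, decoder output, and target features.

import torch
from models.modules.visualizer import visualize

B, S, T, D = 2, 7, 50, 512
batch = {
    "text_tokens": torch.randint(0, 100, (B, S)),
    "text_tokens_lens": torch.tensor([7, 5]),
    "audio_features": torch.randn(B, T, 80),
    "audio_features_lens": torch.tensor([50, 42]),
    "utt_id": ["utt_0", "utt_1"],
    "text": ["hello world", "voice star"],
}
predicts = (torch.randn(B, S, D), torch.randn(B, T, 80))
visualize(predicts, batch, output_dir=".", limit=2)   # writes ./utt_0.png and ./utt_1.png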
models/voice_star.py ADDED
@@ -0,0 +1,784 @@
 
1
+ import random, os, copy
2
+ from typing import Dict, Iterator, List, Tuple, Union
3
+ import logging
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchmetrics.classification import MulticlassAccuracy
10
+ import torch.distributed as dist
11
+
12
+ from .modules.utils import make_pad_mask, generate_partial_autoregressive_mask
13
+
14
+ from .modules.embedding import SinePositionalEmbedding, TokenEmbedding, SinePositionalEmbedding_progress
15
+ from .modules.transformer import (
16
+ AdaptiveLayerNorm,
17
+ LayerNorm,
18
+ TransformerDecoderLayer,
19
+ TransformerDecoder,
20
+ TransformerEncoder,
21
+ TransformerEncoderLayer,
22
+ )
23
+
24
+ def top_k_top_p_filtering(
25
+ logits, top_k=0, top_p=1.0, min_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
26
+ ):
27
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
28
+ Args:
29
+ logits: logits distribution shape (batch size, vocabulary size)
30
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
31
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
32
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
33
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
34
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
35
+ """
36
+ if min_p < 1.0:
37
+ probs = F.softmax(logits, dim=-1)
38
+ indices_to_remove = probs < min_p
39
+ if not torch.any(indices_to_remove.sum(-1) == logits.size(-1)):
40
+ logits[indices_to_remove] = filter_value
41
+ top_k = 0
42
+ top_p = 1.0
43
+ # else will use other types of sampling, or no filtering
44
+
45
+ # If top_k is a single integer
46
+ if isinstance(top_k, int) and top_k > 0:
47
+ # Safety check to ensure we don't ask for more than available
48
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
49
+
50
+ # Remove all tokens with a probability less than the last token of the top-k
51
+ threshold = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
52
+ indices_to_remove = logits < threshold
53
+ logits[indices_to_remove] = filter_value
54
+
55
+ # If top_k is a list, assume it has the same length as M
56
+ elif isinstance(top_k, list):
57
+ # Ensure the length matches the first dimension
58
+ assert len(top_k) == logits.size(0), \
59
+ f"top_k list length ({len(top_k)}) must match logits.size(0) ({logits.size(0)})"
60
+
61
+ for i in range(logits.size(0)):
62
+ k_i = top_k[i]
63
+ if k_i > 0:
64
+ # Safety check
65
+ k_i = min(max(k_i, min_tokens_to_keep), logits.size(-1))
66
+ row_threshold = torch.topk(logits[i], k_i, dim=-1)[0][-1]
67
+ indices_to_remove_i = logits[i] < row_threshold
68
+ logits[i, indices_to_remove_i] = filter_value
69
+
70
+ if top_p < 1.0:
71
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
72
+ cumulative_probs = torch.cumsum(
73
+ F.softmax(sorted_logits, dim=-1), dim=-1
74
+ )
75
+
76
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
77
+ sorted_indices_to_remove = cumulative_probs > top_p
78
+ if min_tokens_to_keep > 1:
79
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
80
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
81
+ # Shift the indices to the right to keep also the first token above the threshold
82
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
83
+ ..., :-1
84
+ ].clone()
85
+ sorted_indices_to_remove[..., 0] = 0
+ # scatter the sorted mask back to the original index order and apply it; without
+ # this step the nucleus (top-p) mask computed above never reaches `logits`
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = filter_value
86
+
87
+ return logits
88
+
89
+
90
+ def topk_sampling(logits, top_k=10, top_p=1.0, min_p=1.0, temperature=1.0):
91
+ # temperature: (`optional`) float
92
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
93
+ # top_k: (`optional`) int
94
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
95
+ # top_p: (`optional`) float
96
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
97
+
98
+ # Temperature (higher temperature => more likely to sample low probability tokens)
99
+ if temperature != 1.0:
100
+ logits = logits / temperature
101
+ # Top-p/top-k filtering
102
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
103
+ # Sample
104
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
105
+ return token
106
+
107
+
108
+
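A hedged usage sketch (editor's addition, not part of the uploaded files) for the sampling helpers above. The filtering mutates logits in place, so pass a clone if the originals are needed again; vocabulary size, k values, and thresholds are arbitrary, and the import path follows the repo layout.

import torch
from models.voice_star import topk_sampling

logits = torch.randn(4, 100)                                    # (batch, vocab)
tok = topk_sampling(logits.clone(), top_k=30, top_p=0.9)        # [4, 1] sampled token ids
tok_minp = topk_sampling(logits.clone(), top_k=0, top_p=1.0, min_p=0.05)  # min-p mode (skipped internally if it would empty a row)
tok_perrow = topk_sampling(logits.clone(), top_k=[10, 20, 30, 40])        # per-example top-k; list length must equal batch size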
109
+ class VoiceStar(nn.Module):
110
+ def __init__(self, args):
111
+ super().__init__()
112
+ self.args = args
113
+ assert self.args.enc_dec ^ self.args.dec, f"self.args.enc_dec: {self.args.enc_dec}, self.args.dec: {self.args.dec}"
114
+ if not getattr(self.args, "special_first", False):
115
+ self.args.special_first = 0
116
+ if not getattr(self.args, "n_special", False):
117
+ self.args.n_special = 3
118
+ self.args.eos = getattr(self.args, "eos", -1)
119
+ self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1]
120
+ if self.args.eos > 0:
121
+ assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
122
+ self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
123
+ if type(self.args.audio_vocab_size) == str:
124
+ self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
125
+ if type(self.args.audio_vocab_size) == list: # otherwise they are all lists
126
+ assert self.args.special_first
127
+
128
+
129
+ self.n_text_tokens = self.args.text_vocab_size + 1
130
+ assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"
131
+
132
+ if self.args.special_first and type(self.args.audio_vocab_size) == list:
133
+ self.n_audio_tokens = [tok + self.args.n_special for tok in self.args.audio_vocab_size] # special tokens: empty token, EOG token, audio pad token
134
+ assert self.args.empty_token == 0, self.args.empty_token
135
+ assert self.args.eog == 1, self.args.eog
136
+ assert self.args.audio_pad_token == 2, self.args.audio_pad_token
137
+ else:
138
+ self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
139
+ assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
140
+ assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
141
+ assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token
142
+
143
+ self.text_embedding = TokenEmbedding(
144
+ dim_model=self.args.d_model,
145
+ vocab_size=self.n_text_tokens,
146
+ dropout=self.args.text_embedding_dropout
147
+ )
148
+
149
+ self.audio_embedding = nn.ModuleList(
150
+ [
151
+ TokenEmbedding(
152
+ dim_model=self.args.audio_embedding_dim,
153
+ vocab_size=self.n_audio_tokens[k],
154
+ dropout=self.args.audio_embedding_dropout
155
+ ) for k in range(self.args.n_codebooks)
156
+ ]
157
+ )
158
+
159
+ rope_base = getattr(self.args, "rope_base", None)
160
+ use_sinusoidal = getattr(self.args, "use_sinusoidal", False)
161
+ use_sinusoidal_progress = getattr(self.args, "use_sinusoidal_progress", False)
162
+ logging.info(f"rope_base: {rope_base}, use_sinusoidal: {use_sinusoidal}")
163
+ if use_sinusoidal:
164
+ self.text_positional_embedding = SinePositionalEmbedding(
165
+ self.args.d_model,
166
+ dropout=self.args.text_positional_embedding_dropout,
167
+ scale=False,
168
+ alpha=True, # learnable scaler, scale the volume of positional embedding
169
+ )
170
+ self.audio_positional_embedding = SinePositionalEmbedding(
171
+ self.args.d_model,
172
+ dropout=self.args.audio_positional_embedding_dropout,
173
+ scale=False,
174
+ alpha=True, # learnable scaler, scale the volume of positional embedding
175
+ )
176
+ elif use_sinusoidal_progress:
177
+ self.text_positional_embedding = SinePositionalEmbedding_progress(
178
+ self.args.d_model,
179
+ dropout=self.args.text_positional_embedding_dropout,
180
+ scale=False,
181
+ alpha=True, # learnable scaler, scale the volume of positional embedding
182
+ args = self.args
183
+ )
184
+ self.audio_positional_embedding = SinePositionalEmbedding_progress(
185
+ self.args.d_model,
186
+ dropout=self.args.audio_positional_embedding_dropout,
187
+ scale=False,
188
+ alpha=True, # learnable scaler, scale the volume of positional embedding
189
+ args = self.args
190
+ )
191
+
192
+ else:
193
+ class NoOp:
194
+ def __init__(self):
195
+ pass
196
+ def __call__(self, *args, **kwargs):
197
+ return args[0]
198
+ self.text_positional_embedding = NoOp()
199
+ self.audio_positional_embedding = NoOp()
200
+
201
+ if self.args.enc_dec:
202
+ enc_layer = TransformerEncoderLayer(
203
+ d_model=self.args.d_model,
204
+ nhead=self.args.nhead,
205
+ dim_feedforward=self.args.d_model*4,
206
+ dropout=self.args.trm_dropout,
207
+ batch_first=True,
208
+ norm_first=True,
209
+ layer_norm_cls=LayerNorm
210
+ ) # use the pre-norm arch
211
+
212
+ self.encoder = TransformerEncoder(
213
+ encoder_layer=enc_layer,
214
+ num_layers=self.args.num_encoder_layers,
215
+ norm=LayerNorm(self.args.d_model),
216
+ rope_base = self.args.rope_base,
217
+ d_model = self.args.d_model,
218
+ nhead = self.args.nhead,
219
+ args = self.args
220
+ ) # use the pre-norm arch
221
+
222
+ dec_layer = TransformerDecoderLayer(
223
+ d_model=self.args.d_model,
224
+ nhead=self.args.nhead,
225
+ dim_feedforward=self.args.d_model*4,
226
+ dropout=self.args.trm_dropout,
227
+ batch_first=True,
228
+ norm_first=True,
229
+ layer_norm_cls=LayerNorm
230
+ )
231
+
232
+ self.decoder = TransformerDecoder(
233
+ decoder_layer=dec_layer,
234
+ num_layers=self.args.num_decoder_layers,
235
+ norm=LayerNorm(self.args.d_model),
236
+ rope_base = self.args.rope_base,
237
+ d_model = self.args.d_model,
238
+ nhead = self.args.nhead,
239
+ args = self.args
240
+ ) # NOTE: for this one I use the torch.nn native implementation, as it's not implemented in .modules
241
+
242
+ else:
243
+ dec_layer = TransformerEncoderLayer(
244
+ self.args.d_model,
245
+ self.args.nhead,
246
+ dim_feedforward=self.args.d_model * 4,
247
+ dropout=self.args.trm_dropout,
248
+ batch_first=True,
249
+ norm_first=True,
250
+ layer_norm_cls=LayerNorm
251
+ )
252
+ self.decoder = TransformerEncoder(
253
+ dec_layer,
254
+ num_layers=self.args.num_decoder_layers,
255
+ norm=LayerNorm(self.args.d_model),
256
+ )
257
+
258
+ if type(self.args.audio_vocab_size) == int:
259
+ self.predict_layer = nn.ModuleList(
260
+ [
261
+ nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
262
+ ]
263
+ )
264
+ else:
265
+ self.predict_layer = nn.ModuleList(
266
+ [
267
+ nn.Sequential(nn.Linear(self.args.d_model, self.args.d_model//2), nn.GELU(), nn.Linear(self.args.d_model//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
268
+ ]
269
+ )
270
+
271
+ self.accuracy_metrics = nn.ModuleList(
272
+ [MulticlassAccuracy(
273
+ self.n_audio_tokens[k],
274
+ top_k=10,
275
+ average="micro",
276
+ multidim_average="global",
277
+ ignore_index=None,
278
+ ) for k in range(self.args.n_codebooks)]
279
+ )
280
+
281
+ if self.args.eog_weight != 1:
282
+ raise NotImplementedError("now have different vocab_size for different codebooks, therefore currently don't support eog_weight")
283
+ self.class_weight = nn.Parameter(torch.ones(self.n_audio_tokens), requires_grad=False)
284
+ self.class_weight.data[self.args.eog] = self.args.eog_weight
285
+
286
+ def dec_forward(
287
+ self,
288
+ x_input,
289
+ x_lens,
290
+ x_attention_mask,
291
+ x_padding_mask,
292
+ y_input,
293
+ new_y_lens,
294
+ y_attention_mask,
295
+ y_padding_mask,
296
+ need_weights=False,
297
+ past=None,
298
+ last_3_tokens=False
299
+ ):
300
+ x_attn_mask = F.pad(
301
+ x_attention_mask,
302
+ (0, new_y_lens.max()),
303
+ value=True,
304
+ ) # x attends to all of x and none of y; this follows figure 3 of the VALL-E paper
305
+ y_attn_mask = F.pad(
306
+ y_attention_mask,
307
+ (x_lens.max(), 0), # y is padded at the front
308
+ value=False,
309
+ ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
310
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
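For intuition, a small self-contained sketch (illustrative only, not from the repository) of the joint mask assembled above, with a hypothetical 2-token text prefix and 3 audio tokens; True means a blocked position:

import torch
import torch.nn.functional as F

Tx, Ty = 2, 3                                      # hypothetical text / audio lengths
x_attn = torch.triu(torch.ones(Tx, Tx), diagonal=1).bool()
y_attn = torch.triu(torch.ones(Ty, Ty), diagonal=1).bool()
x_rows = F.pad(x_attn, (0, Ty), value=True)        # text rows: attend to text only, never to audio
y_rows = F.pad(y_attn, (Tx, 0), value=False)       # audio rows: attend to all text, causal over audio
joint_mask = torch.cat([x_rows, y_rows], dim=0)    # [(Tx+Ty), (Tx+Ty)], same layout as xy_attn_mask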
311
+
312
+ # merge key padding and attention masks
313
+ bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
314
+ xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
315
+ _xy_padding_mask = (
316
+ xy_padding_mask.view(bsz, 1, 1, src_len)
317
+ .expand(-1, self.args.nhead, -1, -1)
318
+ .reshape(bsz * self.args.nhead, 1, src_len)
319
+ )
320
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
321
+
322
+ new_attn_mask = torch.zeros_like(xy_attn_mask)
323
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
324
+ xy_attn_mask = new_attn_mask
325
+
326
+ xy_input = torch.cat([x_input, y_input], dim=1)
327
+ if need_weights:
328
+ raise NotImplementedError("not implemented yet")
329
+ out, layer_attn_weights = self.decoder((xy_input, None), mask=xy_attn_mask, need_weights=True)
330
+ return layer_attn_weights
331
+
332
+ if past == None: # do not use kvcache
333
+ out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
334
+ return out[:, x_lens.max():], None
335
+ else: # use kvcache
336
+ if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
337
+ if last_3_tokens:
338
+ xy_input = xy_input[:, -3:]
339
+ xy_attn_mask = xy_attn_mask[:, -3:]
340
+ else:
341
+ xy_input = xy_input[:, -1:]
342
+ xy_attn_mask = xy_attn_mask[:, -1:]
343
+
344
+ out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
345
+ if isinstance(out, tuple): # get rid of stage_embedding
346
+ out = out[0]
347
+
348
+ if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
349
+ return out[:, x_lens.max():], present
350
+ else: # used kvcache
351
+ return out, present
352
+
353
+ def enc_dec_forward(
354
+ self,
355
+ xa,
356
+ x_attention_mask,
357
+ x_padding_mask,
358
+ y_input,
359
+ new_y_lens,
360
+ y_attention_mask,
361
+ y_padding_mask,
362
+ tgt_y_lens=None,
363
+ need_weights=False,
364
+ past=None,
365
+ last_3_tokens=False
366
+ ):
367
+ assert not need_weights
368
+ if past != None and past.ndim > 3:
369
+ y_input = y_input[:, -1:]
370
+ y_attention_mask = y_attention_mask[-1:]
371
+ yhat, present = self.decoder(tgt=y_input, memory=xa, tgt_mask=y_attention_mask, tgt_key_padding_mask=y_padding_mask, memory_key_padding_mask=x_padding_mask, query_lens=tgt_y_lens, past=past)
372
+ return yhat, present
373
+
374
+ def forward(self, batch, calc_loss = False):
375
+ """
376
+ Args:
377
+ x:
378
+ A 2-D tensor of shape (N, S).
379
+ x_lens:
380
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
381
+ before padding.
382
+ y:
383
+ A 3-D tensor of shape (N, K, T).
384
+ where K is the number of codebooks
385
+ y_lens:
386
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
387
+ before padding.
388
+ """
389
+ x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
390
+ if len(x) == 0:
391
+ return None
392
+ x = x[:, :x_lens.max()] # this deals with gradient accumulation, where x_lens.max() of the current slice might be smaller than the padded length of x
393
+ y = y[...,:y_lens.max()]
394
+ assert x.ndim == 2, x.shape
395
+ assert x_lens.ndim == 1, x_lens.shape
396
+ assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
397
+ assert y_lens.ndim == 1, y_lens.shape
398
+ x_padding_mask = make_pad_mask(x_lens).to(x.device)
399
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
400
+ x_input = self.text_embedding(x)
401
+ x_input = self.text_positional_embedding(x_input, x_lens)
402
+ y_with_eos = [torch.cat([item[:, :y_lens[i]], self.eos], dim=-1) for i, item in enumerate(y)]
403
+ targets = y_with_eos
404
+ # apply delayed stacking on y
405
+ shifted_y = []
406
+ patterns = []
407
+ new_y_lens = []
408
+ if getattr(self, "empty_tokens", None) == None:
409
+ self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K]
410
+ for i in range(len(y)):
411
+ tmp = torch.cat([y_with_eos[i], self.empty_tokens], dim=-1) # [K, T+n_codebooks]
412
+ for ii in range(self.args.n_codebooks):
413
+ tmp[ii] = torch.roll(tmp[ii], shifts=ii+1, dims=0)
414
+ shifted_y.append(tmp.transpose(1,0)) # [K, T+n_codebooks] -> [T+n_codebooks, K]
415
+ new_y_lens.append(y_with_eos[i].shape[1] + self.empty_tokens.shape[1])
416
+
417
+ new_y_lens = torch.LongTensor(new_y_lens).to(y.device)
418
+
419
+ cated_y = torch.nn.utils.rnn.pad_sequence(shifted_y, batch_first=False, padding_value=self.args.audio_pad_token)
420
+ assert cated_y.shape == torch.Size([max(new_y_lens), len(y), self.args.n_codebooks]), cated_y.shape
421
+ cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B]
422
+ stacked_embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
423
+ assert stacked_embedded_y.shape[0] == self.args.n_codebooks and stacked_embedded_y.shape[2] == len(y) and stacked_embedded_y.shape[-1] == self.args.d_model, stacked_embedded_y.shape
424
+ embedded_y = stacked_embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
425
+ embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
426
+ assert embedded_y.shape[1:] == torch.Size([max(new_y_lens), self.args.d_model]), embedded_y.shape
427
+ y_input = self.audio_positional_embedding(embedded_y, new_y_lens)
428
+ y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
429
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
430
+ if self.args.dec:
431
+ y_out = self.dec_forward(
432
+ x_input,
433
+ x_lens,
434
+ x_attention_mask,
435
+ x_padding_mask,
436
+ y_input,
437
+ new_y_lens,
438
+ y_attention_mask,
439
+ y_padding_mask
440
+ )
441
+ else:
442
+ xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask)
443
+ y_out = self.enc_dec_forward(
444
+ xa,
445
+ x_attention_mask,
446
+ x_padding_mask,
447
+ y_input,
448
+ new_y_lens,
449
+ y_attention_mask,
450
+ y_padding_mask
451
+ )
452
+ y_out = y_out[0] # no kv-caching during training
453
+ assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
454
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
455
+ assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
456
+ logits_use = [logit[:, :new_y_lens[i]] for i, logit in enumerate(logits)] # each of shape [K, T, card]
457
+ logits_final = []
458
+ for i, logit in enumerate(logits_use):
459
+ logit_copy = logit.clone()
460
+ for ii in range(self.args.n_codebooks):
461
+ logit_copy[ii] = torch.roll(logit_copy[ii], shifts=-ii, dims=0)
462
+ logit = logit_copy[:, :-self.args.n_codebooks] # [K, T, card] -> [K, T-n_codebooks, card]
463
+ logits_final.append(logit)
464
+ if self.args.no_loss_on_prefix:
465
+ assert "y_sep_token_position" in batch, f"y_sep_token_position should be in batch, but it's not"
466
+ logit_temp = []
467
+ target_temp = []
468
+ for jj, (logit, target) in enumerate(zip(logits_final, targets)):
469
+ # TODO already taken into consideration in depth transformer
470
+ logit_temp.append(logit[:, batch['y_sep_token_position'][jj]:])
471
+ target_temp.append(target[:, batch['y_sep_token_position'][jj]:])
472
+ logits_final = logit_temp
473
+ targets = target_temp
474
+ logits = torch.cat(logits_final, dim=1) # [K, T1+T2+T3+..., card]
475
+ targets = torch.cat(targets, dim=1) # [K, T1+T2+T3+...]
476
+
477
+ assert targets.shape[:2] == logits.shape[:2], f"{targets.shape}, {logits.shape}"
478
+ loss = []
479
+ ntokens = []
480
+ top10acc = []
481
+ for k, (logit, target) in enumerate(zip(logits, targets)): # even though the loss and top10acc is calculated in a loop (loop through n_codebooks), validation is still taking a lot of mem, need to optimize this a little more
482
+ loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None, ignore_index=self.args.y_sep_token if self.args.y_sep_token != None else -100)) # ignore audio sep token as it's unpredictable (like the random early stop bug happened in 2023)
483
+ # NOTE have to ignore the sep token in the loss calculation
484
+ top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
485
+ ntokens.append(len(logit))
486
+
487
+ all_ntokens = sum(ntokens)
488
+ if self.args.codebook_weight != None:
489
+ codebook_weight = eval(self.args.codebook_weight) if isinstance(self.args.codebook_weight, str) else self.args.codebook_weight
490
+ else:
491
+ codebook_weight = [1.] * self.args.n_codebooks
492
+ perplexity_by_codebook = [torch.exp(l) for l in loss]
493
+ loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
494
+
495
+ top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
496
+ top10acc = sum(top10acc_by_codebook)
497
+
498
+ ntokens = torch.tensor(all_ntokens).to(logits.device)
499
+
500
+ ret = {
501
+ "loss": loss,
502
+ "perplexity_by_codebook": perplexity_by_codebook,
503
+ "top10acc": top10acc,
504
+ "top10acc_by_codebook": top10acc_by_codebook,
505
+ "effective_ntoken": ntokens,
506
+ }
507
+
508
+ return ret
509
+
510
+ def inference_tts(
511
+ self,
512
+ x: torch.Tensor,
513
+ x_lens: torch.Tensor,
514
+ y: torch.Tensor,
515
+ tgt_y_lens: torch.Tensor, #
516
+ top_k: Union[int, list[int]]=-100,
517
+ top_p: float=1.0,
518
+ min_p: float=1.0,
519
+ temperature: float=1.0,
520
+ stop_repetition: int=3,
521
+ kvcache: int=1,
522
+ silence_tokens: list[int]=[],
523
+ multi_trial: list[int]=[],
524
+ *kargs
525
+ ) -> torch.Tensor:
526
+ """
527
+ Unlike the non-cached implementation, this version uses a key/value cache (kvcache), which should give a significant speed-up (a small cache-growth sketch follows this docstring).
528
+ Args:
529
+ x:
530
+ A 2-D tensor of shape (1, L).
531
+ x_lens:
532
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
533
+ before padding.
534
+ y:
535
+ A 3-D tensor of shape (1, T, K).
536
+ tgt_y_lens:
537
+ *new arg* this specifies the target length of y
538
+ top_k: (`optional`) int
539
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
540
+ top_p: (`optional`) float
541
+ For nucleus (top-p) sampling
542
+ min_p: (`optional`) float
543
+ For min_p filtered sampling
544
+ temperature: (`optional`) float
545
+ The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
546
+ multi_trial: (`optional`) list[int]
547
+ If not empty, it will be [n_trials, beam_size, trial_interval]
548
+ starting from the beginning, every trial_interval steps we duplicate the current sample beam_size times,
549
+ and at the end of each trial_interval we keep the sample with the highest log-likelihood and discard the rest
550
+ """
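The cache-growth sketch referenced in the docstring: a self-contained illustration (not from the repository; shapes are hypothetical) of how the key/value cache is extended in the decoding loop below. The first pass caches the whole prefix, and each later pass appends one position along the time axis:

import torch

num_layers, bsz, nhead, head_dim = 2, 1, 4, 16
past = torch.ones(num_layers, 2, bsz)              # low-dimensional placeholder, as in the loop below
for step in range(3):
    new_len = 10 if past.ndim <= 3 else 1          # full prefix on the first pass, then one token per step
    present = torch.zeros(num_layers, 2, bsz, nhead, new_len, head_dim)
    past = torch.cat([past, present], dim=-2) if past.ndim > 3 else present
print(past.shape[-2])                              # 12 cached positions after three steps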
551
+ eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
552
+ assert x.ndim == 2, x.shape
553
+ assert x_lens.ndim == 1, x_lens.shape
554
+ assert y.ndim == 3, y.shape
555
+ if self.args.special_first:
556
+ y = y + int(self.args.n_special)
557
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
558
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
559
+
560
+ # make x attention mask and x_input
561
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
562
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
563
+ x_input = self.text_embedding(x)
564
+ x_input = self.text_positional_embedding(x_input, x_lens)
565
+
566
+ y_len = y.shape[2]
567
+ y_lens = torch.LongTensor([y_len]).to(y.device)
568
+
569
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
570
+ rearranged_y = [[y[0]]]
571
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
572
+
573
+ # # shift y to create the delayed pattern
574
+ if getattr(self, "empty_tokens", None) == None:
575
+ self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K]
576
+ temp = rearranged_y[0][0]
577
+ assert temp.ndim == 2 and temp.shape[0] == self.args.n_codebooks, temp.shape
578
+ temp = torch.cat([temp, self.empty_tokens], dim=-1) # [K, T+n_codebooks]
579
+ for ii in range(self.args.n_codebooks):
580
+ temp[ii] = torch.roll(temp[ii], shifts=ii+1, dims=0)
581
+ shifted_y = [[temp]]
582
+
583
+ # below is different from forward or inference
584
+ # where we cut this shifted part
585
+ shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
586
+ assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
587
+ # the next section in the editing inference inserts mask tokens at the boundaries between spans of a sample, but we don't need that here
588
+ # next section is concate tensors of each sample to one tensor, which we also don't need
589
+ cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
590
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
591
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
592
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
593
+
594
+ # replace tokens in y with their embeddings and sum over codebooks
595
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
596
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
597
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
598
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
599
+ embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
600
+
601
+ # positional embedding
602
+ y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens)
603
+
604
+ # make attention mask and padding mask
605
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
606
+
607
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
608
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
609
+
610
+ # entering the generation stage
611
+ # starting from line 708
612
+ codebook_eog = [False] * self.args.n_codebooks
613
+ generated = [] # doesn't contain any empty token, contain eog
614
+ cur_generated = []
615
+ # say 0 is empty, 4 is eog
616
+ # tensor([[ 1, 2, 3, 4, 0, 0],
617
+ # [ 0, 1, 2, 3, 4, 0],
618
+ # [ 0, 0, 1, 2, 3, 4]])
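A self-contained sketch (illustrative only, not from the repository) of the delayed codebook pattern that the torch.roll calls in this file build, for 3 codebooks, 4 content tokens and 0 as the empty token:

import torch

K, T, empty = 3, 4, 0
y = torch.arange(1, T + 1).repeat(K, 1)                        # [[1, 2, 3, 4]] for each codebook
y = torch.cat([y, torch.full((K, K), empty, dtype=torch.long)], dim=-1)
for k in range(K):
    y[k] = torch.roll(y[k], shifts=k + 1, dims=0)              # same shift as in the model code
# y is now the staircase pattern:
# tensor([[0, 1, 2, 3, 4, 0, 0],
#         [0, 0, 1, 2, 3, 4, 0],
#         [0, 0, 0, 1, 2, 3, 4]])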
619
+ num_gen = []
620
+ cur_num_gen = 0
621
+ ##################### silence repetition handling #####################
622
+ ##################### silence repetition handling #####################
623
+ # silence_tokens = [1388,1898,131] # [1388, 2045, 2041, 1996]
624
+ # silence_tokens = []
625
+ consec_silence_count = 0
626
+ prev_token = None
627
+ ##################### silence repetition handling #####################
628
+ ##################### silence repetition handling #####################
629
+
630
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
631
+ if n_eog == 0:
632
+ logits_adjust = logits
633
+ for jj in range(1,self.args.n_codebooks):
634
+ logits_adjust[jj][eog_inference] = -10000
635
+ logits_adjust[jj][self.args.empty_token] = -10000
636
+ if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
637
+ logits_adjust[0][eog_inference] = -10000
638
+ ##################### silence repetition handling #####################
639
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
640
+ if logits_adjust[0, prev_token] < 0:
641
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
642
+ else:
643
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
644
+ ##################### silence repetition handling #####################
645
+ samples = topk_sampling(
646
+ logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
647
+ ) # [K, 1]
648
+ assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
649
+ if cur_num_gen < self.args.n_codebooks-1:
650
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
651
+ samples[-jj, 0] = self.args.empty_token
652
+
653
+ if (
654
+ samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//4)
655
+ ) or self.args.rope_base is not None and not self.args.decoder_regular_rope and self.args.progress_no_multiple and cur_num_gen > (tgt_y_lens[0] + self.args.encodec_sr * getattr(self.args, "extra_cutoff", 5)):
656
+ # the last condition in the first bracket means y is already too long; this shouldn't happen, but we guard against it here
657
+ # the second bracket means we are using progress-monitoring RoPE but the model is generating an excessively long sequence (5 seconds more than specified), in which case we terminate the generation
658
+ samples[0,0] = eog_inference
659
+ codebook_eog[0] = True
660
+ ##################### silence repetition handling #####################
661
+ if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
662
+ consec_silence_count += 1
663
+ else:
664
+ consec_silence_count = 0
665
+ prev_token = samples[0,0]
666
+ ##################### silence repetition handling #####################
667
+ return samples, codebook_eog, prev_token, consec_silence_count
668
+ else:
669
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
670
+ logits_adjust = logits
671
+ for jj in range(n_eog+1,self.args.n_codebooks):
672
+ logits_adjust[jj][eog_inference] = -10000
673
+ logits_adjust[jj][self.args.empty_token] = -10000
674
+ samples = topk_sampling(
675
+ logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
676
+ ) # [K, 1]
677
+ for jj in range(n_eog):
678
+ samples[jj, 0] = self.args.empty_token
679
+ samples[n_eog, 0] = eog_inference
680
+ codebook_eog[n_eog] = True
681
+ return samples, codebook_eog, prev_token, consec_silence_count
682
+
683
+ # prepare the cache placeholder
684
+ # n_layers, 2, bsz, num_heads, src_len, head_dim, 2 means [key, value]
685
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
686
+ if self.args.enc_dec:
687
+ xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask)
688
+ while True:
689
+ if self.args.dec:
690
+ y_out, present = self.dec_forward(
691
+ x_input,
692
+ x_lens,
693
+ x_attention_mask,
694
+ x_padding_mask,
695
+ y_input,
696
+ new_y_lens,
697
+ y_attention_mask,
698
+ y_padding_mask,
699
+ past=past
700
+ )
701
+ else:
702
+ y_out, present = self.enc_dec_forward(
703
+ xa,
704
+ x_attention_mask,
705
+ x_padding_mask,
706
+ y_input,
707
+ new_y_lens,
708
+ y_attention_mask,
709
+ y_padding_mask,
710
+ tgt_y_lens=tgt_y_lens,
711
+ past=past
712
+ )
713
+ if past != None:
714
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
715
+
716
+
717
+ y_out = y_out[:, -1:] # only take the last token
718
+
719
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
720
+ logits = logits.squeeze(0).squeeze(1) # [K card]
721
+ assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
722
+
723
+ n_eog = sum(codebook_eog)
724
+ assert n_eog < self.args.n_codebooks
725
+ if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
726
+ for jj in range(self.args.n_codebooks):
727
+ logits[jj][self.args.eog] = -10000.
728
+
729
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
730
+ # samples.shape is [K,1]
731
+ # get samples_emb
732
+ samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
733
+ samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
734
+
735
+ cur_num_gen += 1
736
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
737
+
738
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
739
+ codebook_eog = [False] * self.args.n_codebooks
740
+ num_gen.append(cur_num_gen)
741
+ cur_num_gen = 0
742
+ generated.append(cur_generated)
743
+ cur_generated = []
744
+ break
745
+ else:
746
+ assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
747
+
748
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
749
+ new_y_lens = torch.LongTensor([embedded_y.shape[1]]).to(y.device)
750
+ y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) # [B T D]
751
+ # make attention mask and padding mask
752
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
753
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
754
+
755
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
756
+
757
+ # revert the pattern
758
+ flatten_gen = []
759
+ for l, orig_span in enumerate(generated):
760
+ span = torch.stack(orig_span, dim=0) # [T, K]
761
+ span = span.transpose(1,0) # [K, T]
762
+ assert span.shape[0] == self.args.n_codebooks, span.shape
763
+ unshifted_span = []
764
+ for j, s in enumerate(span):
765
+ start_from = j
766
+ end_at = - (self.args.n_codebooks - start_from)
767
+ unshifted_span.append(s[start_from:end_at])
768
+ unshifted_span = torch.stack(unshifted_span, dim=0)
769
+
770
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
771
+
772
+ flatten_gen.append(unshifted_span)
773
+ assert len(flatten_gen) == 1, len(flatten_gen)
774
+
775
+ # combine
776
+ res = [y[0], flatten_gen[0]]
777
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
778
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
779
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
780
+
781
+ if self.args.special_first:
782
+ res = res - int(self.args.n_special)
783
+ flatten_gen = [item - int(self.args.n_special) for item in flatten_gen]  # flatten_gen is a list of tensors, so shift each one back
784
+ return res, flatten_gen[0].unsqueeze(0)
pretrained/.gitkeep ADDED
File without changes
steps/__init__.py ADDED
File without changes
steps/optim.py ADDED
@@ -0,0 +1,1123 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ import random
20
+ from collections import defaultdict
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+
28
+
29
+ class BatchedOptimizer(Optimizer):
30
+ """
31
+ This class adds to class Optimizer the capability to optimize parameters in batches:
32
+ it will stack the parameters and their grads for you so the optimizer can work
33
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
34
+ as it reduces the number of kernels launched in the optimizer.
35
+
36
+ Args:
37
+ params:
38
+ """
39
+
40
+ def __init__(self, params, defaults):
41
+ super(BatchedOptimizer, self).__init__(params, defaults)
42
+
43
+ @contextlib.contextmanager
44
+ def batched_params(self, param_group, group_params_names):
45
+ """
46
+ This function returns (technically, yields) a list of
47
+         tuples (p, state, p_names), where
48
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
49
+ that share the same shape, and its gradient is also stacked;
50
+ `state` is the state corresponding to this batch of parameters
51
+ (it will be physically located in the "state" for one of the real
52
+ parameters, the last one that has any particular shape and dtype).
53
+
54
+ This function is decorated as a context manager so that it can
55
+ write parameters back to their "real" locations.
56
+
57
+ The idea is, instead of doing:
58
+ <code>
59
+ for p in group["params"]:
60
+ state = self.state[p]
61
+ ...
62
+ </code>
63
+ you can do:
64
+ <code>
65
+ with self.batched_params(group["params"]) as batches:
66
+ for p, state, p_names in batches:
67
+ ...
68
+ </code>
69
+
70
+ Args:
71
+ group: a parameter group, which is a list of parameters; should be
72
+ one of self.param_groups.
73
+ group_params_names: name for each parameter in group,
74
+ which is List[str].
75
+ """
76
+ batches = defaultdict(
77
+ list
78
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
79
+ batches_names = defaultdict(
80
+ list
81
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
82
+
83
+ assert len(param_group) == len(group_params_names), f"len(param_group): {len(param_group)}, len(group_params_names): {len(group_params_names)}"
84
+ for p, named_p in zip(param_group, group_params_names):
85
+ key = (str(p.dtype), *p.shape)
86
+ batches[key].append(p)
87
+ batches_names[key].append(named_p)
88
+
89
+ batches_names_keys = list(batches_names.keys())
90
+ sorted_idx = sorted(
91
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
92
+ )
93
+ batches_names = [
94
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
95
+ ]
96
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
97
+
98
+ stacked_params_dict = dict()
99
+
100
+ # turn batches into a list, in deterministic order.
101
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
102
+ # one for each batch in `batches`.
103
+ tuples = []
104
+
105
+ for batch, batch_names in zip(batches, batches_names):
106
+ p = batch[0]
107
+ # we arbitrarily store the state in the
108
+ # state corresponding to the 1st parameter in the
109
+ # group. class Optimizer will take care of saving/loading state.
110
+ state = self.state[p]
111
+ p_stacked = torch.stack(batch)
112
+ grad = torch.stack(
113
+ [
114
+ torch.zeros_like(p) if p.grad is None else p.grad
115
+ for p in batch
116
+ ]
117
+ )
118
+ p_stacked.grad = grad
119
+ stacked_params_dict[key] = p_stacked
120
+ tuples.append((p_stacked, state, batch_names))
121
+
122
+ yield tuples # <-- calling code will do the actual optimization here!
123
+
124
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
125
+ for i, p in enumerate(batch): # batch is list of Parameter
126
+ p.copy_(stacked_params[i])
127
+
128
+
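A small self-contained sketch (not from the repository) of the grouping rule batched_params relies on: parameters are bucketed by (dtype, shape) so same-shaped tensors can be stacked and updated together:

import torch.nn as nn
from collections import defaultdict

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
buckets = defaultdict(list)
for name, p in model.named_parameters():
    buckets[(str(p.dtype), *p.shape)].append(name)
# the two (8, 8) weights land in one bucket, the two (8,) biases in another
print(dict(buckets))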
129
+ class ScaledAdam(BatchedOptimizer):
130
+ """
131
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
132
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
133
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
134
+ param = underlying_param * log_scale.exp())
135
+
136
+
137
+ Args:
138
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
139
+ lr: The learning rate. We will typically use a learning rate schedule that starts
140
+ at 0.03 and decreases over time, i.e. much higher than other common
141
+ optimizers.
142
+ clipping_scale: (e.g. 2.0)
143
+ A scale for gradient-clipping: if specified, the normalized gradients
144
+ over the whole model will be clipped to have 2-norm equal to
145
+ `clipping_scale` times the median 2-norm over the most recent period
146
+ of `clipping_update_period` minibatches. By "normalized gradients",
147
+ we mean after multiplying by the rms parameter value for this tensor
148
+ [for non-scalars]; this is appropriate because our update is scaled
149
+ by this quantity.
150
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
151
+            Must satisfy 0 < beta1 <= beta2 < 1.
152
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
153
+            scale of each parameter tensor and scalar parameters of the model.
154
+ If each parameter were decomposed
155
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
156
+            would be the scaling factor on the learning rate of p_scale.
157
+ eps: A general-purpose epsilon to prevent division by zero
158
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
159
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
160
+ parameter tensor to be >= this value)
161
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
162
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
163
+ parameter tensor to be <= this value)
164
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
165
+ model has any parameters with numel() == 1).
166
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
167
+ of the parameter tensor. This is provided to save a little time
168
+ in the update.
169
+ clipping_update_period: if clipping_scale is specified, this is the period
170
+ """
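A compressed sketch (illustrative only, not from the repository) of the core ScaledAdam update described above: the Adam-normalized gradient is rescaled by the parameter's RMS, so each step changes the parameter by a roughly constant relative amount; momentum and the learned-scale update are omitted:

import torch

p = torch.randn(256, 256)
grad = torch.randn_like(p)
exp_avg_sq = grad ** 2                                   # stand-in for the running second moment
param_rms = (p ** 2).mean().sqrt().clamp(min=1.0e-05)
lr, beta1, eps = 3e-02, 0.9, 1.0e-08
delta = -lr * (1 - beta1) * param_rms * grad / (exp_avg_sq.sqrt() + eps)
p = p + delta                                            # relative step size is ~lr regardless of p's scale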
171
+
172
+ def __init__(
173
+ self,
174
+ params,
175
+ lr=3e-02,
176
+ clipping_scale=None,
177
+ betas=(0.9, 0.98),
178
+ scalar_lr_scale=0.1,
179
+ eps=1.0e-08,
180
+ param_min_rms=1.0e-05,
181
+ param_max_rms=3.0,
182
+ scalar_max=10.0,
183
+ size_update_period=4,
184
+ clipping_update_period=100,
185
+ parameters_names=None,
186
+ show_dominant_parameters=True,
187
+ ):
188
+
189
+ assert parameters_names is not None, (
190
+             "Please prepare parameters_names, "
191
+             "which is a List[List[str]]. Each List[str] is for a group "
192
+ "and each str is for a parameter"
193
+ )
194
+ defaults = dict(
195
+ lr=lr,
196
+ clipping_scale=clipping_scale,
197
+ betas=betas,
198
+ scalar_lr_scale=scalar_lr_scale,
199
+ eps=eps,
200
+ param_min_rms=param_min_rms,
201
+ param_max_rms=param_max_rms,
202
+ scalar_max=scalar_max,
203
+ size_update_period=size_update_period,
204
+ clipping_update_period=clipping_update_period,
205
+ )
206
+
207
+ super(ScaledAdam, self).__init__(params, defaults)
208
+ assert len(self.param_groups) == len(parameters_names)
209
+ self.parameters_names = parameters_names
210
+ self.show_dominant_parameters = show_dominant_parameters
211
+
212
+ def __setstate__(self, state):
213
+ super(ScaledAdam, self).__setstate__(state)
214
+
215
+ @torch.no_grad()
216
+ def step(self, closure=None):
217
+ """Performs a single optimization step.
218
+
219
+ Arguments:
220
+ closure (callable, optional): A closure that reevaluates the model
221
+ and returns the loss.
222
+ """
223
+ loss = None
224
+ if closure is not None:
225
+ with torch.enable_grad():
226
+ loss = closure()
227
+
228
+ batch = True
229
+
230
+ for group, group_params_names in zip(
231
+ self.param_groups, self.parameters_names
232
+ ):
233
+
234
+ with self.batched_params(
235
+ group["params"], group_params_names
236
+ ) as batches:
237
+
238
+ # batches is list of pairs (stacked_param, state). stacked_param is like
239
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
240
+ # a stacking dim, it is not a real dim.
241
+
242
+ if (
243
+ len(batches[0][1]) == 0
244
+ ): # if len(first state) == 0: not yet initialized
245
+ clipping_scale = 1
246
+ else:
247
+ clipping_scale = self._get_clipping_scale(group, batches)
248
+
249
+ for p, state, _ in batches:
250
+ # Perform optimization step.
251
+ # grad is not going to be None, we handled that when creating the batches.
252
+ grad = p.grad
253
+ if grad.is_sparse:
254
+ raise RuntimeError(
255
+ "ScaledAdam optimizer does not support sparse gradients"
256
+ )
257
+ # State initialization
258
+ if len(state) == 0:
259
+ self._init_state(group, p, state)
260
+
261
+ self._step_one_batch(group, p, state, clipping_scale)
262
+
263
+ return loss
264
+
265
+ def _init_state(self, group: dict, p: Tensor, state: dict):
266
+ """
267
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
268
+ is actually the batch dimension, corresponding to batched-together
269
+ parameters of a given shape.
270
+
271
+
272
+ Args:
273
+ group: Dict to look up configuration values.
274
+ p: The parameter that we are initializing the state for
275
+ state: Dict from string to whatever state we are initializing
276
+ """
277
+ size_update_period = group["size_update_period"]
278
+
279
+ state["step"] = 0
280
+
281
+ kwargs = {"device": p.device, "dtype": p.dtype}
282
+
283
+ # 'delta' implements conventional momentum. There are
284
+ # several different kinds of update going on, so rather than
285
+ # compute "exp_avg" like in Adam, we store and decay a
286
+ # parameter-change "delta", which combines all forms of
287
+ # update. this is equivalent to how it's done in Adam,
288
+ # except for the first few steps.
289
+ state["delta"] = torch.zeros_like(
290
+ p, memory_format=torch.preserve_format
291
+ )
292
+
293
+ batch_size = p.shape[0]
294
+ numel = p.numel() // batch_size
295
+ numel = p.numel()
296
+
297
+ if numel > 1:
298
+ # "param_rms" just periodically records the scalar root-mean-square value of
299
+ # the parameter tensor.
300
+ # it has a shape like (batch_size, 1, 1, 1, 1)
301
+ param_rms = (
302
+ (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
303
+ )
304
+ state["param_rms"] = param_rms
305
+
306
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
307
+ state["scale_grads"] = torch.zeros(
308
+ size_update_period, *param_rms.shape, **kwargs
309
+ )
310
+
311
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
312
+ state["exp_avg_sq"] = torch.zeros_like(
313
+ p, memory_format=torch.preserve_format
314
+ )
315
+
316
+ def _get_clipping_scale(
317
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
318
+ ) -> float:
319
+ """
320
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
321
+ by this amount before applying the rest of the update.
322
+
323
+ Args:
324
+ group: the parameter group, an item in self.param_groups
325
+ tuples: a list of tuples of (param, state, param_names)
326
+ where param is a batched set of parameters,
327
+ with a .grad (1st dim is batch dim)
328
+ and state is the state-dict where optimization parameters are kept.
329
+ param_names is a List[str] while each str is name for a parameter
330
+ in batched set of parameters "param".
331
+ """
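A toy sketch (not from the repository; the norms are hypothetical) of the clipping rule implemented below: the gradient is scaled by min(1, clipping_scale * median_recent_norm / current_norm):

import torch

recent_norms = torch.tensor([1.0, 1.2, 0.9, 1.1, 5.0])    # hypothetical recent whole-model gradient norms
clipping_scale = 2.0
threshold = clipping_scale * recent_norms.median()         # 2.0 * 1.1 = 2.2
current_norm = torch.tensor(5.0)
scale = min(1.0, (threshold / (current_norm + 1.0e-20)).item())
# scale == 0.44: this unusually large gradient is clipped back to ~2x the median norm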
332
+ assert len(tuples) >= 1
333
+ clipping_scale = group["clipping_scale"]
334
+ (first_p, first_state, _) = tuples[0]
335
+ step = first_state["step"]
336
+ if clipping_scale is None or step == 0:
337
+ # no clipping. return early on step == 0 because the other
338
+ # parameters' state won't have been initialized yet.
339
+ return 1.0
340
+ clipping_update_period = group["clipping_update_period"]
341
+
342
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
343
+ for (p, state, param_names) in tuples:
344
+ grad = p.grad
345
+ if grad.is_sparse:
346
+ raise RuntimeError(
347
+ "ScaledAdam optimizer does not support sparse gradients"
348
+ )
349
+ if p.numel() == p.shape[0]: # a batch of scalars
350
+ tot_sumsq += (
351
+ grad ** 2
352
+ ).sum() # sum() to change shape [1] to []
353
+ else:
354
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
355
+
356
+ tot_norm = tot_sumsq.sqrt()
357
+ if "model_norms" not in first_state:
358
+ first_state["model_norms"] = torch.zeros(
359
+ clipping_update_period, device=p.device
360
+ )
361
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
362
+
363
+ if step % clipping_update_period == 0:
364
+ # Print some stats.
365
+ # We don't reach here if step == 0 because we would have returned
366
+ # above.
367
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
368
+ quartiles = []
369
+ for n in range(0, 5):
370
+ index = min(
371
+ clipping_update_period - 1,
372
+ (clipping_update_period // 4) * n,
373
+ )
374
+ quartiles.append(sorted_norms[index].item())
375
+
376
+ median = quartiles[2]
377
+ threshold = clipping_scale * median
378
+ first_state["model_norm_threshold"] = threshold
379
+ percent_clipped = (
380
+ first_state["num_clipped"] * 100.0 / clipping_update_period
381
+ if "num_clipped" in first_state
382
+ else 0.0
383
+ )
384
+ first_state["num_clipped"] = 0
385
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
386
+ logging.info(
387
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
388
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
389
+ )
390
+
391
+ if step < clipping_update_period:
392
+ return 1.0 # We have not yet estimated a norm to clip to.
393
+ else:
394
+ try:
395
+ model_norm_threshold = first_state["model_norm_threshold"]
396
+ except KeyError:
397
+ logging.info(
398
+ "Warning: model_norm_threshold not in state: possibly "
399
+ "you changed config when restarting, adding clipping_scale option?"
400
+ )
401
+ return 1.0
402
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
403
+ if ans < 1.0:
404
+ first_state["num_clipped"] += 1
405
+ if ans < 0.1:
406
+ logging.warn(
407
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
408
+ )
409
+ if self.show_dominant_parameters:
410
+ assert p.shape[0] == len(param_names)
411
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
412
+ return ans
413
+
414
+ def _show_gradient_dominating_parameter(
415
+ self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
416
+ ):
417
+ """
418
+         Show information about the parameter which dominates tot_sumsq.
419
+
420
+ Args:
421
+ tuples: a list of tuples of (param, state, param_names)
422
+ where param is a batched set of parameters,
423
+ with a .grad (1st dim is batch dim)
424
+ and state is the state-dict where optimization parameters are kept.
425
+ param_names is a List[str] while each str is name for a parameter
426
+ in batched set of parameters "param".
427
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
428
+ from tuples, we still pass it to save some time.
429
+ """
430
+ all_sumsq_orig = {}
431
+ for (p, state, batch_param_names) in tuples:
432
+ # p is a stacked batch parameters.
433
+ batch_grad = p.grad
434
+ if p.numel() == p.shape[0]: # a batch of scalars
435
+ batch_sumsq_orig = batch_grad ** 2
436
+                 # Dummy values used by the following `zip` statement.
437
+ batch_rms_orig = torch.ones(p.shape[0])
438
+ else:
439
+ batch_rms_orig = state["param_rms"]
440
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
441
+ dim=list(range(1, batch_grad.ndim))
442
+ )
443
+
444
+ for name, sumsq_orig, rms, grad in zip(
445
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
446
+ ):
447
+
448
+ proportion_orig = sumsq_orig / tot_sumsq
449
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
450
+
451
+ assert torch.isclose(
452
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
453
+ torch.tensor(1.0),
454
+ )
455
+ sorted_by_proportion = {
456
+ k: v
457
+ for k, v in sorted(
458
+ all_sumsq_orig.items(),
459
+ key=lambda item: item[1][0],
460
+ reverse=True,
461
+ )
462
+ }
463
+ dominant_param_name = next(iter(sorted_by_proportion))
464
+ (
465
+ dominant_proportion,
466
+ dominant_sumsq,
467
+ dominant_rms,
468
+ dominant_grad,
469
+ ) = sorted_by_proportion[dominant_param_name]
470
+ logging.info(
471
+             f"Parameter dominating tot_sumsq: {dominant_param_name}"
472
+ f" with proportion {dominant_proportion:.2f},"
473
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
474
+ f"={dominant_sumsq:.3e},"
475
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
476
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
477
+ )
478
+
479
+ def _step_one_batch(
480
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
481
+ ):
482
+ """
483
+ Do the step for one parameter, which is actually going to be a batch of
484
+ `real` parameters, with dim 0 as the batch dim.
485
+ Args:
486
+ group: dict to look up configuration values
487
+ p: parameter to update (actually multiple parameters stacked together
488
+ as a batch)
489
+ state: state-dict for p, to look up the optimizer state
490
+ """
491
+ lr = group["lr"]
492
+ size_update_period = group["size_update_period"]
493
+ beta1 = group["betas"][0]
494
+
495
+ grad = p.grad
496
+ if clipping_scale != 1.0:
497
+ grad = grad * clipping_scale
498
+ step = state["step"]
499
+ delta = state["delta"]
500
+
501
+ delta.mul_(beta1)
502
+ batch_size = p.shape[0]
503
+ numel = p.numel() // batch_size
504
+ if numel > 1:
505
+ # Update the size/scale of p, and set param_rms
506
+ scale_grads = state["scale_grads"]
507
+ scale_grads[step % size_update_period] = (p * grad).sum(
508
+ dim=list(range(1, p.ndim)), keepdim=True
509
+ )
510
+ if step % size_update_period == size_update_period - 1:
511
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
512
+ param_rms.copy_(
513
+ (p ** 2)
514
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
515
+ .sqrt()
516
+ )
517
+ if step > 0:
518
+ # self._size_update() learns the overall scale on the
519
+ # parameter, by shrinking or expanding it.
520
+ self._size_update(group, scale_grads, p, state)
521
+
522
+ if numel == 1:
523
+ # For parameters with 1 element we just use regular Adam.
524
+ # Updates delta.
525
+ self._step_scalar(group, p, state)
526
+ else:
527
+ self._step(group, p, state)
528
+
529
+ state["step"] = step + 1
530
+
531
+ def _size_update(
532
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
533
+ ) -> None:
534
+ """
535
+ Called only where p.numel() > 1, this updates the scale of the parameter.
536
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
537
+ gradient descent on underlying param and on scale, this function does the update
538
+ on `scale`.
539
+
540
+ Args:
541
+ group: dict to look up configuration values
542
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
543
+ grads w.r.t. the scales.
544
+ p: The parameter to update
545
+ state: The state-dict of p
546
+ """
547
+
548
+ param_rms = state["param_rms"]
549
+ beta1, beta2 = group["betas"]
550
+ size_lr = group["lr"] * group["scalar_lr_scale"]
551
+ param_min_rms = group["param_min_rms"]
552
+ param_max_rms = group["param_max_rms"]
553
+ eps = group["eps"]
554
+ step = state["step"]
555
+ batch_size = p.shape[0]
556
+
557
+ size_update_period = scale_grads.shape[0]
558
+ # correct beta2 for the size update period: we will have
559
+ # faster decay at this level.
560
+ beta2_corr = beta2 ** size_update_period
561
+
562
+ scale_exp_avg_sq = state[
563
+ "scale_exp_avg_sq"
564
+ ] # shape: (batch_size, 1, 1, ..)
565
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
566
+ (scale_grads ** 2).mean(
567
+ dim=0
568
+ ), # mean over dim `size_update_period`
569
+ alpha=1 - beta2_corr,
570
+ ) # shape is (batch_size, 1, 1, ...)
571
+
572
+ # The 1st time we reach here is when size_step == 1.
573
+ size_step = (step + 1) // size_update_period
574
+ bias_correction2 = 1 - beta2_corr ** size_step
575
+ # we don't bother with bias_correction1; this will help prevent divergence
576
+ # at the start of training.
577
+
578
+ denom = scale_exp_avg_sq.sqrt() + eps
579
+
580
+ scale_step = (
581
+ -size_lr
582
+ * (bias_correction2 ** 0.5)
583
+ * scale_grads.sum(dim=0)
584
+ / denom
585
+ )
586
+
587
+ is_too_small = param_rms < param_min_rms
588
+ is_too_large = param_rms > param_max_rms
589
+
590
+ # when the param gets too small, just don't shrink it any further.
591
+ scale_step.masked_fill_(is_too_small, 0.0)
592
+ # when it gets too large, stop it from getting any larger.
593
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
594
+ delta = state["delta"]
595
+ # the factor of (1-beta1) relates to momentum.
596
+ delta.add_(p * scale_step, alpha=(1 - beta1))
597
+
598
+ def _step(self, group: dict, p: Tensor, state: dict):
599
+ """
600
+ This function does the core update of self.step(), in the case where the members of
601
+ the batch have more than 1 element.
602
+
603
+ Args:
604
+ group: A dict which will be used to look up configuration values
605
+ p: The parameter to be updated
606
+ grad: The grad of p
607
+ state: The state-dict corresponding to parameter p
608
+
609
+ This function modifies p.
610
+ """
611
+ grad = p.grad
612
+ lr = group["lr"]
613
+ beta1, beta2 = group["betas"]
614
+ eps = group["eps"]
615
+ param_min_rms = group["param_min_rms"]
616
+ step = state["step"]
617
+
618
+ exp_avg_sq = state["exp_avg_sq"]
619
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
620
+
621
+ this_step = state["step"] - (
622
+ state["zero_step"] if "zero_step" in state else 0
623
+ )
624
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
625
+ if bias_correction2 < 0.99:
626
+ # note: not in-place.
627
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
628
+
629
+ denom = exp_avg_sq.sqrt()
630
+ denom += eps
631
+ grad = grad / denom
632
+
633
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
634
+
635
+ delta = state["delta"]
636
+ delta.add_(grad * alpha)
637
+ p.add_(delta)
638
+
639
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
640
+ """
641
+ A simplified form of the core update for scalar tensors, where we cannot get a good
642
+ estimate of the parameter rms.
643
+ """
644
+ beta1, beta2 = group["betas"]
645
+ scalar_max = group["scalar_max"]
646
+ eps = group["eps"]
647
+ lr = group["lr"] * group["scalar_lr_scale"]
648
+ grad = p.grad
649
+
650
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
651
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
652
+
653
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
654
+ # slower update at the start will help stability anyway.
655
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
656
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
657
+
658
+ delta = state["delta"]
659
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
660
+ p.clamp_(min=-scalar_max, max=scalar_max)
661
+ p.add_(delta)
662
+
663
+
664
+ class LRScheduler(object):
665
+ """
666
+ Base-class for learning rate schedulers where the learning-rate depends on both the
667
+ batch and the epoch.
668
+ """
669
+
670
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
671
+ # Attach optimizer
672
+ if not isinstance(optimizer, Optimizer):
673
+ raise TypeError(
674
+ "{} is not an Optimizer".format(type(optimizer).__name__)
675
+ )
676
+ self.optimizer = optimizer
677
+ self.verbose = verbose
678
+
679
+ for group in optimizer.param_groups:
680
+ group.setdefault("base_lr", group["lr"])
681
+
682
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
683
+
684
+ self.epoch = 0
685
+ self.batch = 0
686
+
687
+ def state_dict(self):
688
+ """Returns the state of the scheduler as a :class:`dict`.
689
+
690
+ It contains an entry for every variable in self.__dict__ which
691
+ is not the optimizer.
692
+ """
693
+ return {
694
+ "base_lrs": self.base_lrs,
695
+ "epoch": self.epoch,
696
+ "batch": self.batch,
697
+ }
698
+
699
+ def load_state_dict(self, state_dict):
700
+ """Loads the schedulers state.
701
+
702
+ Args:
703
+ state_dict (dict): scheduler state. Should be an object returned
704
+ from a call to :meth:`state_dict`.
705
+ """
706
+ self.__dict__.update(state_dict)
707
+
708
+ def get_last_lr(self) -> List[float]:
709
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
710
+ return self._last_lr
711
+
712
+ def get_lr(self):
713
+ # Compute list of learning rates from self.epoch and self.batch and
714
+ # self.base_lrs; this must be overloaded by the user.
715
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
716
+ raise NotImplementedError
717
+
718
+ def step_batch(self, batch: Optional[int] = None) -> None:
719
+ # Step the batch index, or just set it. If `batch` is specified, it
720
+ # must be the batch index from the start of training, i.e. summed over
721
+ # all epochs.
722
+ # You can call this in any order; if you don't provide 'batch', it should
723
+ # of course be called once per batch.
724
+ if batch is not None:
725
+ self.batch = batch
726
+ else:
727
+ self.batch = self.batch + 1
728
+ self._set_lrs()
729
+
730
+ def step_epoch(self, epoch: Optional[int] = None):
731
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
732
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
733
+ # arg, you should call it at the end of the epoch.
734
+ if epoch is not None:
735
+ self.epoch = epoch
736
+ else:
737
+ self.epoch = self.epoch + 1
738
+ self._set_lrs()
739
+
740
+ def _set_lrs(self):
741
+ values = self.get_lr()
742
+ assert len(values) == len(self.optimizer.param_groups)
743
+
744
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
745
+ param_group, lr = data
746
+ param_group["lr"] = lr
747
+ self.print_lr(self.verbose, i, lr)
748
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
749
+
750
+ def print_lr(self, is_verbose, group, lr):
751
+ """Display the current learning rate."""
752
+ if is_verbose:
753
+ logging.info(
754
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
755
+ f" of group {group} to {lr:.4e}."
756
+ )
757
+
758
+
759
+ class Eden(LRScheduler):
760
+ """
761
+ Eden scheduler.
762
+ The basic formula (before warmup) is:
763
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
764
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
765
+        where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
766
+ and then stays constant at 1.
767
+
768
+
769
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
770
+
771
+ Args:
772
+ optimizer: the optimizer to change the learning rates on
773
+ lr_batches: the number of batches after which we start significantly
774
+ decreasing the learning rate, suggest 5000.
775
+ lr_epochs: the number of epochs after which we start significantly
776
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
777
+ 20 to 40 epochs, but may need smaller number if dataset is huge
778
+ and you will do few epochs.
779
+ """
780
+
781
+ def __init__(
782
+ self,
783
+ optimizer: Optimizer,
784
+ lr_batches: Union[int, float],
785
+ lr_epochs: Union[int, float],
786
+ warmup_batches: Union[int, float] = 500.0,
787
+ verbose: bool = False,
788
+ ):
789
+ super(Eden, self).__init__(optimizer, verbose)
790
+ self.lr_batches = lr_batches
791
+ self.lr_epochs = lr_epochs
792
+ self.warmup_batches = warmup_batches
793
+
794
+ def get_lr(self):
795
+ factor = (
796
+ (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
797
+ ) ** -0.25 * (
798
+ ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
799
+ ** -0.25
800
+ )
801
+ warmup_factor = (
802
+ 1.0
803
+ if self.batch >= self.warmup_batches
804
+ else 0.5 + 0.5 * (self.batch / self.warmup_batches)
805
+ )
806
+
807
+ return [x * factor * warmup_factor for x in self.base_lrs]
808
+
809
+
810
+ def _test_eden():
811
+ m = torch.nn.Linear(100, 100)
812
+ optim = ScaledAdam(m.parameters(), lr=0.03)
813
+
814
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
815
+
816
+ for epoch in range(10):
817
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
818
+
819
+ for step in range(20):
820
+ x = torch.randn(200, 100).detach()
821
+ x.requires_grad = True
822
+ y = m(x)
823
+ dy = torch.randn(200, 100).detach()
824
+ f = (y * dy).sum()
825
+ f.backward()
826
+
827
+ optim.step()
828
+ scheduler.step_batch()
829
+ optim.zero_grad()
830
+
831
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
832
+ logging.info(f"state dict = {scheduler.state_dict()}")
833
+
834
+
835
+ # This is included mostly as a baseline for ScaledAdam.
836
+ class Eve(Optimizer):
837
+ """
838
+ Implements Eve algorithm. This is a modified version of AdamW with a special
839
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
840
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
841
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
842
+ will be close to invariant to the absolute scale on the parameter matrix.
843
+
844
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
845
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
846
+ Eve is unpublished so far.
847
+
848
+ Arguments:
849
+ params (iterable): iterable of parameters to optimize or dicts defining
850
+ parameter groups
851
+ lr (float, optional): learning rate (default: 1e-3)
852
+ betas (Tuple[float, float], optional): coefficients used for computing
853
+ running averages of gradient and its square (default: (0.9, 0.999))
854
+ eps (float, optional): term added to the denominator to improve
855
+ numerical stability (default: 1e-8)
856
+             weight_decay (float, optional): weight decay coefficient (default: 1e-3;
857
+                with this value the weight would decay noticeably after about 1k
858
+                minibatches). It is not multiplied by the learning rate, and is only
859
+                applied when the RMS value of the parameter exceeds target_rms.
860
+ target_rms (float, optional): target root-mean-square value of
861
+ parameters, if they fall below this we will stop applying weight decay.
862
+
863
+
864
+ .. _Adam: A Method for Stochastic Optimization:
865
+ https://arxiv.org/abs/1412.6980
866
+ .. _Decoupled Weight Decay Regularization:
867
+ https://arxiv.org/abs/1711.05101
868
+ .. _On the Convergence of Adam and Beyond:
869
+ https://openreview.net/forum?id=ryQu7f-RZ
870
+ """
871
+
872
+ def __init__(
873
+ self,
874
+ params,
875
+ lr=1e-3,
876
+ betas=(0.9, 0.98),
877
+ eps=1e-8,
878
+ weight_decay=1e-3,
879
+ target_rms=0.1,
880
+ ):
881
+ if not 0.0 <= lr:
882
+ raise ValueError("Invalid learning rate: {}".format(lr))
883
+ if not 0.0 <= eps:
884
+ raise ValueError("Invalid epsilon value: {}".format(eps))
885
+ if not 0.0 <= betas[0] < 1.0:
886
+ raise ValueError(
887
+ "Invalid beta parameter at index 0: {}".format(betas[0])
888
+ )
889
+ if not 0.0 <= betas[1] < 1.0:
890
+ raise ValueError(
891
+ "Invalid beta parameter at index 1: {}".format(betas[1])
892
+ )
893
+ if not 0 <= weight_decay <= 0.1:
894
+ raise ValueError(
895
+ "Invalid weight_decay value: {}".format(weight_decay)
896
+ )
897
+ if not 0 < target_rms <= 10.0:
898
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
899
+ defaults = dict(
900
+ lr=lr,
901
+ betas=betas,
902
+ eps=eps,
903
+ weight_decay=weight_decay,
904
+ target_rms=target_rms,
905
+ )
906
+ super(Eve, self).__init__(params, defaults)
907
+
908
+ def __setstate__(self, state):
909
+ super(Eve, self).__setstate__(state)
910
+
911
+ @torch.no_grad()
912
+ def step(self, closure=None):
913
+ """Performs a single optimization step.
914
+
915
+ Arguments:
916
+ closure (callable, optional): A closure that reevaluates the model
917
+ and returns the loss.
918
+ """
919
+ loss = None
920
+ if closure is not None:
921
+ with torch.enable_grad():
922
+ loss = closure()
923
+
924
+ for group in self.param_groups:
925
+ for p in group["params"]:
926
+ if p.grad is None:
927
+ continue
928
+
929
+ # Perform optimization step
930
+ grad = p.grad
931
+ if grad.is_sparse:
932
+ raise RuntimeError(
933
+                         "Eve does not support sparse gradients"
934
+ )
935
+
936
+ state = self.state[p]
937
+
938
+ # State initialization
939
+ if len(state) == 0:
940
+ state["step"] = 0
941
+ # Exponential moving average of gradient values
942
+ state["exp_avg"] = torch.zeros_like(
943
+ p, memory_format=torch.preserve_format
944
+ )
945
+ # Exponential moving average of squared gradient values
946
+ state["exp_avg_sq"] = torch.zeros_like(
947
+ p, memory_format=torch.preserve_format
948
+ )
949
+
950
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
951
+
952
+ beta1, beta2 = group["betas"]
953
+
954
+ state["step"] += 1
955
+ bias_correction1 = 1 - beta1 ** state["step"]
956
+ bias_correction2 = 1 - beta2 ** state["step"]
957
+
958
+ # Decay the first and second moment running average coefficient
959
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
960
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
961
+ denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
962
+ group["eps"]
963
+ )
964
+
965
+ step_size = group["lr"] / bias_correction1
966
+ target_rms = group["target_rms"]
967
+ weight_decay = group["weight_decay"]
968
+
969
+ if p.numel() > 1:
970
+ # avoid applying this weight-decay on "scaling factors"
971
+ # (which are scalar).
972
+ is_above_target_rms = p.norm() > (
973
+ target_rms * (p.numel() ** 0.5)
974
+ )
975
+ p.mul_(1 - (weight_decay * is_above_target_rms))
976
+
977
+ p.addcdiv_(exp_avg, denom, value=-step_size)
978
+
979
+ # if random.random() < 0.0005:
980
+ # step = (exp_avg / denom) * step_size
981
+ # logging.info(
982
+ # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
983
+ # )
984
+
985
+ return loss
986
+
987
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
988
+ """
989
+ Behaves like a constructor of a modified version of nn.Linear
990
+ that gives an easy way to set the default initial parameter scale.
991
+
992
+ Args:
993
+ Accepts the standard args and kwargs that nn.Linear accepts
994
+ e.g. in_features, out_features, bias=False.
995
+
996
+ initial_scale: you can override this if you want to increase
997
+ or decrease the initial magnitude of the module's output
998
+           (it directly scales the initial values of the weight and bias).
999
+ Another option, if you want to do something like this, is
1000
+ to re-initialize the parameters.
1001
+ """
1002
+ ans = nn.Linear(*args, **kwargs)
1003
+ with torch.no_grad():
1004
+ ans.weight[:] *= initial_scale
1005
+ if ans.bias is not None:
1006
+ torch.nn.init.uniform_(
1007
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
1008
+ )
1009
+ return ans
1010
+ def _test_scaled_adam(hidden_dim: int):
1011
+ import timeit
1012
+
1013
+ E = 100
1014
+ B = 4
1015
+ T = 2
1016
+     logging.info("in _test_scaled_adam")
1017
+ # device = torch.device('cuda')
1018
+ device = torch.device("cpu")
1019
+ dtype = torch.float32
1020
+
1021
+ # these input_magnitudes and output_magnitudes are to test that
1022
+ # Abel is working as we expect and is able to adjust scales of
1023
+ # different dims differently.
1024
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1025
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1026
+
1027
+ for iter in [1, 0]:
1028
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
1029
+
1030
+ m = torch.nn.Sequential(
1031
+ Linear(E, hidden_dim),
1032
+ torch.nn.PReLU(),
1033
+ Linear(hidden_dim, hidden_dim),
1034
+ torch.nn.PReLU(),
1035
+ Linear(hidden_dim, E),
1036
+ ).to(device)
1037
+
1038
+ train_pairs = [
1039
+ (
1040
+ 100.0
1041
+ * torch.randn(B, T, E, device=device, dtype=dtype)
1042
+ * input_magnitudes,
1043
+ torch.randn(B, T, E, device=device, dtype=dtype)
1044
+ * output_magnitudes,
1045
+ )
1046
+ for _ in range(20)
1047
+ ]
1048
+
1049
+ if iter == 0:
1050
+ optim = Eve(m.parameters(), lr=0.003)
1051
+ elif iter == 1:
1052
+ optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
1053
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
1054
+
1055
+ start = timeit.default_timer()
1056
+ avg_loss = 0.0
1057
+ for epoch in range(180):
1058
+ scheduler.step_epoch()
1059
+ # if epoch == 100 and iter in [2,3]:
1060
+ # optim.reset_speedup() # check it doesn't crash.
1061
+
1062
+ # if epoch == 130:
1063
+ # opts = diagnostics.TensorDiagnosticOptions(
1064
+ # 2 ** 22
1065
+ # ) # allow 4 megabytes per sub-module
1066
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
1067
+
1068
+ for n, (x, y) in enumerate(train_pairs):
1069
+ y_out = m(x)
1070
+ loss = ((y_out - y) ** 2).mean() * 100.0
1071
+ if epoch == 0 and n == 0:
1072
+ avg_loss = loss.item()
1073
+ else:
1074
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
1075
+ if n == 0 and epoch % 5 == 0:
1076
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
1077
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
1078
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
1079
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
1080
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
1081
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
1082
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
1083
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
1084
+ lr = scheduler.get_last_lr()[0]
1085
+ logging.info(
1086
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
1087
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
1088
+ loss.log().backward()
1089
+ optim.step()
1090
+ optim.zero_grad()
1091
+ scheduler.step_batch()
1092
+
1093
+ # diagnostic.print_diagnostics()
1094
+
1095
+ stop = timeit.default_timer()
1096
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
1097
+
1098
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
1099
+ # logging.info("state dict = ", scheduler.state_dict())
1100
+ # logging.info("optim state_dict = ", optim.state_dict())
1101
+ logging.info(f"input_magnitudes = {input_magnitudes}")
1102
+ logging.info(f"output_magnitudes = {output_magnitudes}")
1103
+
1104
+
1105
+ if __name__ == "__main__":
1106
+ torch.set_num_threads(1)
1107
+ torch.set_num_interop_threads(1)
1108
+ logging.getLogger().setLevel(logging.INFO)
1109
+ import subprocess
1110
+
1111
+ s = subprocess.check_output(
1112
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
1113
+ )
1114
+ logging.info(s)
1115
+ import sys
1116
+
1117
+ if len(sys.argv) > 1:
1118
+ hidden_dim = int(sys.argv[1])
1119
+ else:
1120
+ hidden_dim = 200
1121
+
1122
+ _test_scaled_adam(hidden_dim)
1123
+ _test_eden()
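
The Eden schedule above multiplies a batch-wise decay term, an epoch-wise decay term, and a warmup ramp. The standalone sketch below (plain Python, not part of the repository) re-implements the formula from the Eden docstring so the shape of the schedule can be inspected without building an optimizer; the constants 0.04, 5000, 6 and 500 are only the suggested defaults from the docstrings, not values read from any config in this commit.

def eden_lr(base_lr, batch, epoch, lr_batches=5000.0, lr_epochs=6.0, warmup_batches=500.0):
    # batch-wise and epoch-wise decay factors, mirroring Eden.get_lr above
    decay = (
        ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
        * ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    )
    # warmup ramps linearly from 0.5 to 1 over warmup_batches, then stays at 1
    warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * (batch / warmup_batches)
    return base_lr * decay * warmup

for batch, epoch in [(0, 0), (250, 0), (500, 0), (5000, 1), (50000, 4), (200000, 10)]:
    print(f"batch={batch:>6} epoch={epoch:>2} lr={eden_lr(0.04, batch, epoch):.5f}")

Once batch greatly exceeds lr_batches (and epoch exceeds lr_epochs), each decay term behaves roughly like an inverse square root, so the learning rate falls off smoothly rather than in steps.
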
steps/trainer.py ADDED
@@ -0,0 +1,717 @@
1
+ import time, sys, subprocess, json, re
2
+ from pathlib import Path
3
+ import os, random
4
+ import torch
5
+ import math, pickle
6
+ from tqdm import tqdm
7
+ from torch.optim import AdamW
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ import torch.nn as nn
10
+ import torch.distributed as dist
11
+ from torch.utils.data.sampler import Sampler
12
+ import copy
13
+ from torch.utils.tensorboard import SummaryWriter
14
+ import numpy as np
15
+ from torch.utils.data.distributed import DistributedSampler
16
+ import logging
17
+ # from data import librilight, gigaspeech, gigaspeech_waveform
18
+ from data import combined_dataset
19
+ from models import voice_star
20
+
21
+ from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, StatefulSampler, AverageMeter, print_model_info
22
+ from .optim import ScaledAdam, Eden
23
+ import run_gen
24
+ import wandb, socket
25
+
26
+ class Trainer:
27
+
28
+ def __init__(self, args, world_size, rank, local_rank):
29
+ self.start_time = time.time()
30
+ self.args = args
31
+ if self.args.val_max_num_tokens == None:
32
+ self.args.val_max_num_tokens = self.args.max_num_tokens
33
+ self.world_size, self.rank, self.local_rank = world_size, rank, local_rank
34
+ self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
35
+ if self.rank == 0:
36
+ self.writer = SummaryWriter(args.exp_dir)
37
+ self.wandb = wandb.init(project="voice_editor", name=args.exp_dir.split("/")[-1], config=args, dir=args.exp_dir, entity=self.args.wandb_entity)
38
+ self.seed_everything(seed=self.args.seed)
39
+ self.meters = self._setup_meters()
40
+
41
+ self.progress, self.total_progress = self._setup_progress()
42
+
43
+ self.model, self.trainables, self.optim_states, self.scheduler_states, self.phn2num = self._setup_models()
44
+
45
+         self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader() # both use DistributedSampler; the train sampler is stateful
46
+ if self.args.num_steps != None:
47
+ self.total_step = self.args.num_steps
48
+ self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
49
+ else:
50
+ self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs
51
+
52
+ self.optimizer, self.scheduler = self._setup_optimizer()
53
+ self.scaler = torch.cuda.amp.GradScaler()
54
+ self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], find_unused_parameters=False)
55
+ self.early_stop_accu_steps = 0
56
+ if self.rank == 0:
57
+ if self.args.dynamic_batching:
58
+                 logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in an inference batch: {self.args.val_max_num_tokens}")
59
+ else:
60
+ logging.info(f"batch size (per gpu): {self.args.batch_size}")
61
+
62
+ self.args.inference_every_n_steps = getattr(self.args, "inference_every_n_steps", self.args.val_every_n_steps*5)
63
+ assert self.args.inference_every_n_steps > self.args.val_every_n_steps and self.args.inference_every_n_steps % self.args.val_every_n_steps == 0, "inference_every_n_steps should be divisible by val_every_n_steps, otherwise the code will not get a chance to run inference"
64
+
65
+ def train(self):
66
+ flag = True
67
+ skip_flag = False
68
+ data_start_time = time.time()
69
+ if self.progress['step'] >= self.total_step:
70
+ if self.rank == 0:
71
+ self.writer.close()
72
+ self.wandb.finish()
73
+ return
74
+ while flag:
75
+ self.train_sampler.set_epoch(self.progress['epoch'])
76
+ for i, batch in enumerate(self.train_loader):
77
+ if len(batch['y_lens']) < self.args.gradient_accumulation_steps:
78
+ continue
79
+ data_end_time = time.time()
80
+ self.model.train()
81
+ if self.progress['step'] >= getattr(self.args, "uniform_weight_start_step", 1e50):
82
+ if self.progress['step'] == getattr(self.args, "uniform_weight_start_step", 1e50) and self.rank == 0:
83
+ logging.info("NOTE: start using uniform weight from step: {}".format(self.progress['step']))
84
+ self.args.codebook_weight = [2.5,2,1.5,0.6]
85
+ self.model.module.args.codebook_weight = [2.5,2,1.5,0.6]
86
+
87
+ if self.progress['step'] >= self.total_step:
88
+ dist.barrier()
89
+ flag = False
90
+ self.validate_and_save()
91
+ if self.rank == 0:
92
+ self.writer.close()
93
+ self.wandb.finish()
94
+ break
95
+ if isinstance(self.scheduler, Eden):
96
+ self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1)
97
+ if self.args.optimizer_name == "ScaledAdam":
98
+ cur_lr = self.scheduler.get_last_lr()[0]
99
+ else:
100
+ lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
101
+ assert lrs[0] == lrs[1]
102
+ cur_lr = lrs[0]
103
+
104
+ if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
105
+ self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
106
+ self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])
107
+
108
+ all_inds = list(range(len(batch['y'])))
109
+ sum_losses = 0
110
+ sum_top10acc = 0
111
+ sum_ntoken = 0
112
+ sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
113
+ # extra losses
114
+ sum_extra_losses = {}
115
+                 # when using prompt-based training, the prompt can make the total length much longer, which makes the effective batch size in each accumulation step much bigger and can lead to OOM.
116
+                 # therefore we re-calculate gradient_accumulation_steps based on the effective batch size
117
+
118
+ if self.args.neighbor_prompt_prob > 0:
119
+ effective_batch_size = self.args.max_num_tokens // self.args.gradient_accumulation_steps
120
+ total_batch_size = sum(batch['y_lens']).item()
121
+ cur_gradient_accumulation_steps = max(self.args.gradient_accumulation_steps, total_batch_size // effective_batch_size)
122
+ gas = torch.tensor(cur_gradient_accumulation_steps, dtype=torch.int, device=self.local_rank)
123
+ dist.all_reduce(gas, op=dist.ReduceOp.MAX)
124
+ cur_gradient_accumulation_steps = gas.item()
125
+ len_batch = torch.tensor(len(batch['y']), dtype=torch.int, device=self.local_rank)
126
+ dist.all_reduce(len_batch, op=dist.ReduceOp.MIN)
127
+ len_batch = len_batch.item()
128
+ cur_gradient_accumulation_steps = min(cur_gradient_accumulation_steps, len_batch)
129
+                     # when cur_gradient_accumulation_steps * effective_batch_size < total_batch_size, we only use the first cur_gradient_accumulation_steps * effective_batch_size samples
130
+ cur_len = 0
131
+ final_all_inds = []
132
+ pointer = 0
133
+ while cur_len < self.args.max_num_tokens and pointer < len(all_inds):
134
+ cur_len += batch['y_lens'][pointer]
135
+ final_all_inds.append(all_inds[pointer])
136
+ pointer += 1
137
+ all_inds = final_all_inds
138
+ else:
139
+ cur_gradient_accumulation_steps = self.args.gradient_accumulation_steps
140
+
141
+
142
+ sum_losses_local = 0.0
143
+ sum_top10acc_local = 0.0
144
+ sum_entropy_loss_local = 0.0
145
+ sum_ctc_loss_local = 0.0
146
+ sum_ntoken_local = 0.0
147
+ sum_top10acc_cbi_local = [0.0 for _ in range(self.args.n_codebooks)]
148
+
149
+ global_nan_flag = 0
150
+ for j in range(cur_gradient_accumulation_steps):
151
+ cur_ind = all_inds[j::cur_gradient_accumulation_steps]
152
+ cur_batch = {key: batch[key][cur_ind] for key in batch}
153
+
154
+ # Automatic casting
155
+ if self.args.precision == "float16":
156
+ precision_used = torch.float16
157
+ elif self.args.precision in ["bf16", "bfloat16"]:
158
+ precision_used = torch.bfloat16
159
+ else:
160
+ precision_used = torch.float32
161
+
162
+ with torch.amp.autocast('cuda', dtype=precision_used):
163
+ out = self.model(cur_batch, calc_loss=True)
164
+ if out is None:
165
+ continue
166
+
167
+ if torch.isnan(out['loss']).any():
168
+ local_nan_flag = torch.tensor(1, device=self.local_rank)
169
+ else:
170
+ local_nan_flag = torch.tensor(0, device=self.local_rank)
171
+
172
+ # All ranks check if *any* rank got a NaN
173
+ dist.all_reduce(local_nan_flag, op=dist.ReduceOp.SUM)
174
+ global_nan_flag = local_nan_flag.item()
175
+ if global_nan_flag > 0:
176
+ # Now *all* ranks break at the same j
177
+ logging.info(f"rank: {self.rank}. Loss at micro-batch {j} in step {self.progress['step']} was NaN on at least one rank; skipping.")
178
+ break
179
+
180
+ # Accumulate local values
181
+ record_loss = out['loss'].detach()
182
+ top10acc = out['top10acc'].detach()
183
+ effective_ntoken = out['effective_ntoken'].detach()
184
+
185
+ sum_losses_local += record_loss.item()
186
+ sum_top10acc_local += top10acc.item()
187
+ sum_ntoken_local += effective_ntoken.item()
188
+
189
+ # Optional losses
190
+ if 'entropy_loss' in out:
191
+ sum_entropy_loss_local += out['entropy_loss'].detach().item()
192
+ if 'ctc_loss' in out:
193
+ sum_ctc_loss_local += out['ctc_loss'].detach().item()
194
+
195
+ # Codebook accuracy
196
+ if 'top10acc_by_codebook' in out:
197
+ for cb in range(self.args.n_codebooks):
198
+ sum_top10acc_cbi_local[cb] += out['top10acc_by_codebook'][cb].detach().item()
199
+
200
+ # Backprop on this micro-batch
201
+ if self.args.optimizer_name == "ScaledAdam":
202
+ self.scaler.scale(out['loss']).backward()
203
+ else:
204
+ self.scaler.scale(out['loss'] / out['effective_ntoken']).backward()
205
+
206
+ if global_nan_flag > 0:
207
+ # If *any* rank had NaN, skip this step
208
+ logging.info(f"rank: {self.rank}. Loss at one micro-batch in step {self.progress['step']} was NaN on at least one rank; skipping.")
209
+ self.progress['step'] += 1
210
+ self.progress['cur_step'] += 1
211
+ self.optimizer.zero_grad()
212
+ continue
213
+
214
+ # Otherwise, do one big reduce for the summed metrics
215
+ metrics_tensor = torch.tensor([
216
+ sum_losses_local,
217
+ sum_top10acc_local,
218
+ sum_entropy_loss_local,
219
+ sum_ctc_loss_local,
220
+ sum_ntoken_local
221
+ ], device=self.local_rank, dtype=torch.float32)
222
+
223
+ dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
224
+
225
+ # Also reduce the codebook array in one shot if needed
226
+ codebook_tensor = torch.tensor(sum_top10acc_cbi_local, device=self.local_rank, dtype=torch.float32)
227
+ dist.all_reduce(codebook_tensor, op=dist.ReduceOp.SUM)
228
+
229
+ # Convert them back to Python scalars
230
+ sum_losses = metrics_tensor[0].item()
231
+ sum_top10acc = metrics_tensor[1].item()
232
+ sum_entropy_loss = metrics_tensor[2].item()
233
+ sum_ctc_loss = metrics_tensor[3].item()
234
+ sum_ntoken = metrics_tensor[4].item()
235
+
236
+ sum_top10acc_cbi = codebook_tensor.tolist()
237
+
238
+ if self.args.optimizer_name != "ScaledAdam":
239
+ self.scaler.unscale_(self.optimizer)
240
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
241
+
242
+ self.scaler.step(self.optimizer)
243
+ self.scaler.update()
244
+ self.optimizer.zero_grad()
245
+
246
+ if self.args.optimizer_name == "ScaledAdam":
247
+ self.scheduler.step_batch(self.progress['step'])
248
+ else:
249
+ self.scheduler.step()
250
+
251
+ # logging
252
+ if self.rank == 0:
253
+ average_loss = sum_losses / sum_ntoken
254
+ average_top10acc = sum_top10acc / sum_ntoken
255
+ average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
256
+ self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size)
257
+ self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size)
259
+ for cb in range(self.args.n_codebooks):
260
+ self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size)
261
+ self.meters['data_time'].update(data_end_time - data_start_time)
262
+ self.meters['train_time'].update(time.time() - data_end_time)
263
+
264
+ # log extra losses
265
+ for key in sum_extra_losses:
266
+ if "train_"+key not in self.meters:
267
+ self.meters["train_"+key] = AverageMeter()
268
+ self.meters["train_"+key].update(sum(sum_extra_losses[key])/len(sum_extra_losses[key]), batch['x'].shape[0]*self.world_size)
269
+
270
+ if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
271
+ self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
272
+ self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
273
+ self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
274
+ self.wandb.log({"train/loss": average_loss, "train/top10acc": average_top10acc, "train/ntokens": sum_ntoken, "train/data_time": data_end_time - data_start_time, "train/train_time": time.time() - data_end_time}, step=self.progress['step'])
275
+
276
+ for cb in range(self.args.n_codebooks):
277
+ self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])
278
+ self.wandb.log({f'train/top10acc_cb{cb+1}': average_top10acc_cbi[cb]}, step=self.progress['step'])
279
+ self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
280
+ self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])
281
+ # write extra losses
282
+ for key in sum_extra_losses:
283
+ self.writer.add_scalar(f"train/{key}", sum(sum_extra_losses[key])/len(sum_extra_losses[key]), self.progress['step'])
284
+ self.wandb.log({f"train/{key}": sum(sum_extra_losses[key])/len(sum_extra_losses[key])}, step=self.progress['step'])
285
+ # logging.info(f"ntoken: {sum_ntoken}")
286
+
287
+ # logging
288
+ if self.progress['step'] % self.args.print_every_n_steps == 0:
289
+ log_out = {}
290
+ log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
291
+ log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
292
+ log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
293
+ log_out['lr'] = f"{cur_lr:.7f}"
294
+ log_out['ntokens'] = f"{sum_ntoken}"
295
+ for key in self.meters:
296
+ if self.meters[key].val != 0 or self.meters[key].avg != 0:
297
+ log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
298
+ logging.info(log_out)
299
+ if np.isnan(self.meters['train_loss'].avg):
300
+ logging.warning("training diverged...")
301
+ raise RuntimeError("training diverged...")
302
+
303
+ # save the model only
304
+ if self.progress['step'] % self.args.save_every_n_steps == 0:
305
+ dist.barrier()
306
+ if self.rank == 0:
307
+ save_path = os.path.join(self.args.exp_dir,f"bundle_step{self.progress['step']}.pth")
308
+ self.save_progress(name=f"step{self.progress['step']}")
309
+ torch.save(
310
+ {
311
+ "model": self.model.module.state_dict(),
312
+ "args": self.args,
313
+ "phn2num": self.train_loader.dataset.phn2num,
314
+ "optimizer": self.optimizer.state_dict(),
315
+ "scheduler": self.scheduler.state_dict(),
316
+ },save_path
317
+ )
318
+ logging.info(f"save model, optimizer, scheduler and progress at {save_path} at global step {self.progress['step']}")
319
+ dist.barrier()
320
+
321
+ # validation and save models
322
+ if self.progress['step'] % self.args.val_every_n_steps == 0:
323
+ dist.barrier()
324
+ continue_training = self.validate_and_save()
325
+                     # broadcast continue_training to all processes, so that all processes get into the generation stage
326
+ continue_training = torch.tensor(int(continue_training), dtype=torch.int, device=self.local_rank)
327
+ dist.broadcast(continue_training, src=0)
328
+ continue_training = bool(continue_training.item())
329
+ dist.barrier() # need this to ensure all processes get to the next line?
330
+ logging.info(f"rank: {self.rank}, continue_training: {continue_training}")
331
+ if not continue_training:
332
+ if self.rank == 0:
333
+ self.writer.close()
334
+ self.wandb.finish()
335
+ flag = False
336
+ break
337
+
338
+ self.progress['step'] += 1
339
+ self.progress['cur_step'] += 1
340
+
341
+ data_start_time = time.time()
342
+ self.progress['epoch'] += 1
343
+ self.progress['cur_step'] = 0 # reset cur_step to be 0
344
+ dist.destroy_process_group()
345
+
346
+ def validate_and_save(self):
347
+ self.model.eval()
348
+
349
+ score = self.validate(self.valid_loader)
350
+
351
+ if self.args.early_stop_threshold > 0:
352
+ if self.progress['best_score'] - score < self.args.early_stop_threshold:
353
+ self.early_stop_accu_steps += self.args.val_every_n_steps
354
+ if self.early_stop_accu_steps >= self.args.early_stop_step-1:
355
+ logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
356
+ logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
357
+ return False
358
+ else:
359
+ self.early_stop_accu_steps = 0
360
+
361
+ if self.rank == 0:
362
+ save_path = os.path.join(self.args.exp_dir,"bundle.pth")
363
+ if os.path.isfile(save_path):
364
+ os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
365
+ torch.save(
366
+ {
367
+ "model": self.model.module.state_dict(),
368
+ "optimizer": self.optimizer.state_dict(),
369
+ "scheduler": self.scheduler.state_dict(),
370
+ "args": self.args,
371
+ "phn2num": self.train_loader.dataset.phn2num
372
+ },save_path
373
+ )
374
+ self.save_progress()
375
+ logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")
376
+ if (score < self.progress['best_score']):
377
+ self.progress['best_step'] = self.progress['step']
378
+ self.progress['best_score'] = score
379
+ save_path = os.path.join(self.args.exp_dir,"best_bundle.pth")
380
+ if os.path.isfile(save_path):
381
+ os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
382
+ torch.save(
383
+ {
384
+ "model": self.model.module.state_dict(),
385
+ "optimizer": self.optimizer.state_dict(),
386
+ "scheduler": self.scheduler.state_dict(),
387
+ "args": self.args,
388
+ "phn2num": self.train_loader.dataset.phn2num
389
+ },save_path
390
+ )
391
+ logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")
392
+
393
+ # sync best score and best step, so that all processes early stop at the same time
394
+ best_score_tensor = torch.tensor(self.progress['best_score'], device=self.local_rank)
395
+ dist.broadcast(best_score_tensor, src=0)
396
+ self.progress['best_score'] = float(best_score_tensor.item())
397
+ best_step_tensor = torch.tensor(self.progress['best_step'], device=self.local_rank)
398
+ dist.broadcast(best_step_tensor, src=0)
399
+ self.progress['best_step'] = int(best_step_tensor.item())
400
+ dist.barrier()
401
+ return True
402
+
403
+ def validate(self, valid_loader=None, hide_progress=True):
404
+ if valid_loader == None:
405
+ valid_loader = self.valid_loader
406
+ self.model.eval()
407
+
408
+ start_val_time = time.time()
409
+ sum_losses = 0
410
+ sum_top10acc = 0
411
+ sum_ntoken = 0
412
+ sum_dur_loss = 0
413
+ sum_dur_acc = 0
414
+ sum_entropy_loss = 0
415
+ sum_ctc_loss = 0
416
+
417
+ sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
418
+ mean_perplexity_cbi = [0 for _ in range(self.args.n_codebooks)]
419
+
420
+ with torch.no_grad():
421
+ for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
422
+ out = self.model(batch, calc_loss=True) # no reduction is applied to loss
423
+ sum_losses += out['loss']
424
+ sum_top10acc += out['top10acc']
425
+ sum_ntoken += out['effective_ntoken']
426
+ if "dur_loss" in out:
427
+ sum_dur_loss += out['dur_loss']
428
+ sum_dur_acc += out['dur_acc']
429
+ if "entropy_loss" in out:
430
+ sum_entropy_loss += out['entropy_loss']
431
+ if "ctc_loss" in out:
432
+ sum_ctc_loss += out['ctc_loss']
433
+ # logging.info(f"iter {i}::: {sum_losses}, {sum_top10acc}, {sum_ntoken}")
434
+ if 'top10acc_by_codebook' in out:
435
+ for cb in range(self.args.n_codebooks):
436
+ sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]
437
+
438
+ if 'perplexity_by_codebook' in out:
439
+ for cb in range(self.args.n_codebooks):
440
+ mean_perplexity_cbi[cb] += out['perplexity_by_codebook'][cb]
441
+ # if i > 10:
442
+ # break
443
+
444
+
445
+ dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
446
+ dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
447
+ dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
448
+ if "dur_loss" in out:
449
+ dist.all_reduce(sum_dur_loss, op=dist.ReduceOp.SUM)
450
+ dist.all_reduce(sum_dur_acc, op=dist.ReduceOp.SUM)
451
+ if "entropy_loss" in out:
452
+ dist.all_reduce(sum_entropy_loss, op=dist.ReduceOp.SUM)
453
+ if "ctc_loss" in out:
454
+ dist.all_reduce(sum_ctc_loss, op=dist.ReduceOp.SUM)
455
+
456
+ if 'top10acc_by_codebook' in out:
457
+ for cb in range(self.args.n_codebooks):
458
+ dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)
459
+
460
+ if 'perplexity_by_codebook' in out:
461
+ for cb in range(self.args.n_codebooks):
462
+ dist.all_reduce(mean_perplexity_cbi[cb], op=dist.ReduceOp.SUM)
463
+
464
+ val_loss = sum_losses / sum_ntoken
465
+ val_top10acc = sum_top10acc / sum_ntoken
466
+
467
+ if self.rank == 0:
468
+ if "dur_loss" in out:
469
+ val_dur_loss = sum_dur_loss / sum_ntoken
470
+ val_dur_acc = sum_dur_acc / sum_ntoken
471
+ self.meters['val_dur_loss'].update(val_dur_loss)
472
+ logging.info(f"val dur_loss: {val_dur_loss:.5f}")
473
+ self.meters['val_dur_acc'].update(val_dur_acc)
474
+ logging.info(f"val dur_acc: {val_dur_acc:.5f}")
475
+ self.writer.add_scalar("val/dur_loss", val_dur_loss, self.progress['step'])
476
+ self.writer.add_scalar("val/dur_acc", val_dur_acc, self.progress['step'])
477
+ self.wandb.log({"val/dur_loss": val_dur_loss, "val/dur_acc": val_dur_acc}, step=self.progress['step'])
478
+ # logging
479
+ self.meters['val_loss'].update(val_loss)
480
+ logging.info(f"val loss: {val_loss:.5f}")
481
+ self.writer.add_scalar("val/loss", val_loss, self.progress['step'])
482
+ self.wandb.log({"val/loss": val_loss}, step=self.progress['step'])
483
+
484
+ self.meters['val_top10acc'].update(val_top10acc)
485
+ logging.info(f"val top10acc: {val_top10acc:.5f}")
486
+ self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
487
+ self.wandb.log({"val/top10acc": val_top10acc}, step=self.progress['step'])
488
+ for cb in range(self.args.n_codebooks):
489
+ average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
490
+ self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
491
+ self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])
492
+ self.wandb.log({f'val/top10acc_cb{cb+1}': average_top10acc_cbi}, step=self.progress['step'])
493
+
494
+ temp = mean_perplexity_cbi[cb]/len(valid_loader)
495
+ self.writer.add_scalar(f'val/perplexity_cb{cb+1}', temp, self.progress['step'])
496
+ self.wandb.log({f'val/perplexity_cb{cb+1}': temp}, step=self.progress['step'])
497
+
498
+ average_perplexity = sum(mean_perplexity_cbi)/(self.args.n_codebooks*len(valid_loader))
499
+ self.wandb.log({"val/average_perplexity": average_perplexity}, step=self.progress['step'])
500
+ self.writer.add_scalar('val/average_perplexity', average_perplexity, self.progress['step'])
501
+
502
+ # log entropy and ctc loss
503
+ if "entropy_loss" in out:
504
+ val_entropy_loss = sum_entropy_loss / ((i+1) * self.world_size)
505
+ self.meters['val_entropy_loss'].update(val_entropy_loss)
506
+ logging.info(f"val entropy_loss: {val_entropy_loss:.5f}")
507
+ self.writer.add_scalar("val/entropy_loss", val_entropy_loss, self.progress['step'])
508
+ self.wandb.log({"val/entropy_loss": val_entropy_loss}, step=self.progress['step'])
509
+ if "ctc_loss" in out:
510
+ val_ctc_loss = sum_ctc_loss / ((i+1) * self.world_size)
511
+ self.meters['val_ctc_loss'].update(val_ctc_loss)
512
+ logging.info(f"val ctc_loss: {val_ctc_loss:.5f}")
513
+ self.writer.add_scalar("val/ctc_loss", val_ctc_loss, self.progress['step'])
514
+ self.wandb.log({"val/ctc_loss": val_ctc_loss}, step=self.progress['step'])
515
+
516
+ logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
517
+ logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
518
+
519
+ return val_loss.item()
520
+
521
+ def _setup_meters(self):
522
+ meters = {}
523
+ meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
524
+ meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
525
+ meter_names += ['val_perplexity']
526
+ meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
527
+ meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
528
+ meter_names += [f'val_perplexity_cb{cb+1}' for cb in range(self.args.n_codebooks)]
529
+ for name in meter_names:
530
+ meters[name] = AverageMeter()
531
+ return meters
532
+ def _setup_progress(self):
533
+ """
534
+ Need to customize it
535
+ """
536
+ progress = {}
537
+ progress['best_step'] = 1
538
+ progress['best_score'] = np.inf # this records loss value
539
+ progress['step'] = 1
540
+ progress['epoch'] = 1
541
+ progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler
542
+ total_progress = []
543
+ # if self.args.resume or self.args.validate:
544
+ if self.args.resume:
545
+ progress_pkl = "%s/progress.pkl" % self.args.exp_dir
546
+ with open(progress_pkl, "rb") as f:
547
+ total_progress = pickle.load(f)
548
+ progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
549
+ if self.rank == 0:
550
+ logging.info("\nResume training from:")
551
+ logging.info(" epoch = %s" % progress['epoch'])
552
+ logging.info(" cur_step = %s" % progress['cur_step'])
553
+ logging.info(" step = %s" % progress['step'])
554
+ logging.info(" best_step = %s" % progress['best_step'])
555
+ logging.info(" best_score = %s" % progress['best_score'])
556
+ return progress, total_progress
557
+
558
+ def save_progress(self, name=None):
559
+ self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
560
+ if name is not None:
561
+ progress_fn = f"{self.args.exp_dir}/progress_{name}.pkl"
562
+ else:
563
+ progress_fn = f"{self.args.exp_dir}/progress.pkl"
564
+ with open(progress_fn, "wb") as f:
565
+ pickle.dump(self.total_progress, f)
566
+
567
+ def _setup_dataloader(self):
568
+ train_dataset, val_dataset = combined_dataset.dataset(self.args, 'train'), combined_dataset.dataset(self.args, 'valid') # need to change 'train' to 'valid' in actual training
569
+
570
+ if self.args.dynamic_batching:
571
+ train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
572
+ valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
573
+ else:
574
+ train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
575
+ valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)
576
+
577
+ if self.progress['step'] > 1:
578
+ train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])
579
+ assert self.phn2num != None
580
+
581
+ if self.phn2num != None:
582
+ train_dataset.phn2num = self.phn2num
583
+ val_dataset.phn2num = self.phn2num
584
+
585
+ if self.args.dynamic_batching:
586
+ train_loader = torch.utils.data.DataLoader(train_dataset,
587
+ batch_sampler=train_sampler,
588
+ num_workers=self.args.num_workers,
589
+ collate_fn=train_dataset.collate, persistent_workers=True
590
+ )
591
+ valid_loader = torch.utils.data.DataLoader(val_dataset,
592
+ batch_sampler=valid_sampler,
593
+ num_workers=self.args.num_workers,
594
+ collate_fn=val_dataset.collate, persistent_workers=True
595
+ )
596
+ else:
597
+ train_loader = torch.utils.data.DataLoader(train_dataset,
598
+ batch_size=self.args.batch_size, sampler=train_sampler, num_workers=self.args.num_workers,
599
+ collate_fn=train_dataset.collate, persistent_workers=True
600
+ )
601
+ valid_loader = torch.utils.data.DataLoader(val_dataset,
602
+ batch_size=self.args.batch_size, sampler=valid_sampler,
603
+ num_workers=self.args.num_workers,
604
+ collate_fn=val_dataset.collate, persistent_workers=True
605
+ )
606
+ return len(train_dataset), train_sampler, train_loader, valid_loader
607
+
608
+
609
+
610
+ def _setup_models(self):
611
+ model = voice_star.VoiceStar(self.args)
612
+
613
+ if self.rank == 0:
614
+ logging.info(model)
615
+ logging.info("model parameters")
616
+ print_model_info(model)
617
+
618
+ phn2num = None
619
+ optim_states = None
620
+ scheduler_states = None
621
+ if self.progress['step'] > 1:
622
+ bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
623
+ model.load_state_dict(bundle['model'])
624
+ optim_states = bundle['optimizer']
625
+ scheduler_states = bundle['scheduler']
626
+ phn2num = bundle['phn2num']
627
+ if self.rank == 0:
628
+ logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
629
+ del bundle['model']
630
+
631
+ if self.args.load_model_from != None and self.progress['step'] <= 1:
632
+ logging.info(f"load weights from {self.args.load_model_from}")
633
+ sd = torch.load(self.args.load_model_from, map_location="cpu")
634
+ if hasattr(model, "carefully_load_state_dict"):
635
+ model.carefully_load_state_dict(sd['model'])
636
+ else:
637
+ model.load_state_dict(sd['model'])
638
+ phn2num = sd['phn2num']
639
+ del sd
640
+
641
+
642
+         #### the operations below gather the params for the optimizer, which lives at the wrapper level ####
643
+ if self.args.optimizer_name == "ScaledAdam":
644
+ trainables = [p for p in model.parameters() if p.requires_grad]
645
+ else:
646
+ no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
647
+ optimizer_grouped_parameters = [
648
+ {
649
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
650
+ "weight_decay": self.args.weight_decay,
651
+ },
652
+ {
653
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
654
+ "weight_decay": 0.0,
655
+ },
656
+ ]
657
+ if len(optimizer_grouped_parameters[1]['params']) == 0:
658
+                 logging.info("there are no embedding weight, bias, or layernorm parameters in the model, which should not happen; check model parameter names")
659
+ trainables = optimizer_grouped_parameters[0]
660
+ else:
661
+ trainables = optimizer_grouped_parameters
662
+         #### the operations below gather the params for the optimizer, which lives at the wrapper level ####
663
+ model.to(self.device)
664
+
665
+ return model, trainables, optim_states, scheduler_states, phn2num
666
+
667
+
668
+ def _setup_optimizer(self):
669
+ if self.args.optimizer_name == "ScaledAdam":
670
+ parameters_names = []
671
+ _model = self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model
672
+ parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad])
673
+ optimizer = ScaledAdam(
674
+ self.trainables,
675
+ lr=self.args.lr,
676
+ betas=(0.9, 0.95),
677
+ clipping_scale=2.0,
678
+ parameters_names=parameters_names,
679
+ show_dominant_parameters=False,
680
+ clipping_update_period=self.args.clipping_update_period,
681
+ )
682
+ scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction) # NOTE: if using ScaledAdam, we will use the Eden scheduler!
683
+
684
+ else:
685
+ optimizer = AdamW(self.trainables, lr=self.args.lr)
686
+ warmup_steps = self.total_step * self.args.warmup_fraction
687
+ def lr_lambda(current_step: int):
688
+ if current_step < warmup_steps:
689
+ return float(current_step) / float(max(1, warmup_steps))
690
+ return max(
691
+ 0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
692
+ )
693
+
694
+ scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
695
+
696
+ # if resume
697
+ if self.progress['step'] > 1:
698
+ optimizer.load_state_dict(self.optim_states)
699
+ for state in optimizer.state.values():
700
+ for k, v in state.items():
701
+ if isinstance(v, torch.Tensor):
702
+ state[k] = v.cuda()
703
+ del self.optim_states
704
+
705
+ scheduler.load_state_dict(self.scheduler_states)
706
+
707
+ optimizer.zero_grad()
708
+ return optimizer, scheduler
709
+
710
+ def seed_everything(self, seed=1):
711
+ os.environ['PYTHONHASHSEED'] = str(seed)
712
+ random.seed(seed)
713
+ np.random.seed(seed)
714
+ torch.manual_seed(seed)
715
+ torch.cuda.manual_seed(seed)
716
+ torch.backends.cudnn.benchmark = False
717
+ torch.backends.cudnn.deterministic = True
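
When self.args.optimizer_name is not "ScaledAdam", _setup_optimizer above pairs AdamW with a LambdaLR schedule: a linear warmup over warmup_fraction of the total steps followed by a linear decay to zero, stepped once per batch in train(). The minimal sketch below isolates that schedule on a toy model; total_step, warmup_fraction, base_lr and the Linear layer are illustrative placeholders, not values taken from this repository's configs.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(16, 16)
total_step, warmup_fraction, base_lr = 1000, 0.1, 1e-4  # illustrative values only
warmup_steps = total_step * warmup_fraction

def lr_lambda(current_step: int):
    # linear warmup for the first warmup_steps, then linear decay to 0 at total_step
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return max(0.0, float(total_step - current_step) / float(max(1, total_step - warmup_steps)))

optimizer = AdamW(model.parameters(), lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

for step in range(total_step):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # stepped per batch, matching Trainer.train()
    if step % 200 == 0:
        print(step, scheduler.get_last_lr()[0])

With ScaledAdam the trainer instead drives the Eden scheduler via scheduler.step_batch(step) each batch and scheduler.step_epoch(step // pseudo_epoch_size + 1), as shown in train() above.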