Spaces: Running on Zero
Upload 51 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- LICENSE +21 -0
- MODEL-LICENSE +395 -0
- config.py +254 -0
- copy_codebase.py +56 -0
- data/__init__.py +0 -0
- data/combined_dataset.py +466 -0
- data/emilia_preprocessing/delete_tar_files.sh +42 -0
- data/emilia_preprocessing/encodec.py +1554 -0
- data/emilia_preprocessing/sha256hash.py +14 -0
- data/emilia_preprocessing/step1_download.py +9 -0
- data/emilia_preprocessing/step2_log_tar_files.sh +27 -0
- data/emilia_preprocessing/step3_untar.sh +101 -0
- data/emilia_preprocessing/step4_construct_manifest.py +251 -0
- data/emilia_preprocessing/step5_phonemize.py +158 -0
- data/emilia_preprocessing/step6_encodec_encode.py +177 -0
- data/emilia_preprocessing/step6_encodec_encode_script.sh +19 -0
- data/encodec.py +1554 -0
- data/ll60k_preprocessing/config.yaml +75 -0
- data/ll60k_preprocessing/encodec.py +1554 -0
- data/ll60k_preprocessing/step1_download.sh +42 -0
- data/ll60k_preprocessing/step2_resplit_long.py +51 -0
- data/ll60k_preprocessing/step3_seg_phn_manifest.py +194 -0
- data/ll60k_preprocessing/step4_encodec_encode.py +184 -0
- data/ll60k_preprocessing/step4_encodec_encode_script.sh +19 -0
- data/ll60k_preprocessing/step5_find_nearest_neighbor.py +157 -0
- data/ll60k_preprocessing/step6_forced_alignment.py +86 -0
- data/ll60k_preprocessing/step6_forced_alignment.sh +13 -0
- data/ll60k_preprocessing/step7_ipa_alignment.py +114 -0
- data/ll60k_preprocessing/tokenizer.py +460 -0
- data/tokenizer.py +295 -0
- demo/5895_34622_000026_000002.wav +3 -0
- generated_tts/generated.wav +3 -0
- inference_commandline.py +192 -0
- inference_gradio.py +334 -0
- inference_tts_utils.py +155 -0
- main.py +82 -0
- models/modules/__init__.py +0 -0
- models/modules/activation.py +781 -0
- models/modules/embedding.py +158 -0
- models/modules/sampling.py +63 -0
- models/modules/scaling.py +1406 -0
- models/modules/transformer.py +1089 -0
- models/modules/utils.py +36 -0
- models/modules/visualizer.py +107 -0
- models/voice_star.py +784 -0
- pretrained/.gitkeep +0 -0
- steps/__init__.py +0 -0
- steps/optim.py +1123 -0
- steps/trainer.py +717 -0
.gitattributes
CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 examples/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png filter=lfs diff=lfs merge=lfs -text
 examples/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png filter=lfs diff=lfs merge=lfs -text
 illustration.png filter=lfs diff=lfs merge=lfs -text
+demo/5895_34622_000026_000002.wav filter=lfs diff=lfs merge=lfs -text
+generated_tts/generated.wav filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Puyuan Peng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MODEL-LICENSE
ADDED
@@ -0,0 +1,395 @@
Attribution 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution 4.0 International Public License ("Public License"). To the
extent this Public License may be interpreted as a contract, You are
granted the Licensed Rights in consideration of Your acceptance of
these terms and conditions, and the Licensor grants You such rights in
consideration of benefits the Licensor receives from making the
Licensed Material available under these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  j. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  k. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part; and

            b. produce, reproduce, and Share Adapted Material.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.


=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
config.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def int_or_str(value):
|
4 |
+
"""Custom function to allow both int and str types."""
|
5 |
+
try:
|
6 |
+
return int(value) # Try converting to integer
|
7 |
+
except ValueError:
|
8 |
+
return value # If conversion fails, return as string
|
9 |
+
|
10 |
+
def MyParser():
|
11 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
12 |
+
# general training
|
13 |
+
parser.add_argument("--seed", type=int, default=1)
|
14 |
+
parser.add_argument("--debug", type=int, default=0)
|
15 |
+
parser.add_argument("--multinodes", type=int, default=0)
|
16 |
+
parser.add_argument("--dist_url", default="env://", type=str)
|
17 |
+
parser.add_argument("--dist_backend", default="nccl", type=str)
|
18 |
+
parser.add_argument("--precision", type=str, default="float16", help="we might need float32 for NAR model")
|
19 |
+
parser.add_argument("--num_workers", type=int, default=8, help="per gpu")
|
20 |
+
parser.add_argument("--resume", action="store_true", default=False)
|
21 |
+
parser.add_argument("--tb_write_every_n_steps", type=int, default=100)
|
22 |
+
parser.add_argument("--print_every_n_steps", type=int, default=250)
|
23 |
+
parser.add_argument("--val_every_n_steps", type=int, default=500)
|
24 |
+
parser.add_argument("--inference_every_n_steps", type=int, default=3000, help="will only get to inference when model is saved, and therefore this needs to be multiple of val_every_n_steps")
|
25 |
+
parser.add_argument("--save_every_n_steps", type=int, default=10000000, help="save the model every n steps, will save the model as bundle_step$step.pth")
|
26 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
27 |
+
parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size per gpu, no matter whether using gradient_accumulation_steps")
|
28 |
+
parser.add_argument("--weight_decay", type=float, default=1e-2)
|
29 |
+
parser.add_argument("--warmup_fraction", type=float, default=0.1, help="use linear warmup, the proportion of the training steps that are used for warming up")
|
30 |
+
parser.add_argument("--num_epochs", type=int, default=10)
|
31 |
+
parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore n_epochs and use num_steps as the total number of amount of training, can try e.g. 400000 i.e. 400k steps")
|
32 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
33 |
+
parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_()")
|
34 |
+
parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement")
|
35 |
+
parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps")
|
36 |
+
|
37 |
+
|
38 |
+
# path
|
39 |
+
parser.add_argument("--exp_dir", type=str, default='/saltpool0/scratch/pyp/VoiceEditor/', help="will be combined with dataset name")
|
40 |
+
parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'librilight', 'spotify', they are folder name in the data dir also")
|
41 |
+
parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file")
|
42 |
+
parser.add_argument("--compact_folder_name", type=str, default=None, help="if not None, will use compact_combined_dataset.py, and this is the folder name of the compact dataset")
|
43 |
+
parser.add_argument("--inference_dataset_dir", type=str, default="/data/scratch/pyp/datasets/librilight/preprocessed", help="need to be compatible with corresponding dataset py file")
|
44 |
+
|
45 |
+
parser.add_argument("--training_stage", type=int, default=1, help="if 1, train VoiceEditor_one, if 2 train VoiceEditor_seven")
|
46 |
+
parser.add_argument("--local_wandb", type=int, default=0, help="if 1, will use local wandb, otherwise use the global one")
|
47 |
+
parser.add_argument("--wandb_entity", type=str, default="puyuanpeng", help="the entity (usually your username) for wandb")
|
48 |
+
# data
|
49 |
+
parser.add_argument("--librilight_ratio", type=float, default=1, help='the portion of lightlight compared to gigaspeech, 1 means equal, 2 means librilight data is twice as much as gigaspeech')
|
50 |
+
parser.add_argument("--plus_librilight_root", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the root folder to librilight. Note that will need to merge the vocab.txt based on gigaspeech's, in order to be able to load a pretrained model")
|
51 |
+
parser.add_argument("--plus_librilight_phn_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the phoneme folder name of librilight")
|
52 |
+
parser.add_argument("--plus_librilight_encodec_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the encodec folder name of librilight")
|
53 |
+
parser.add_argument("--plus_librilight_manifest_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the manifest folder name of librilight")
|
54 |
+
parser.add_argument("--skip_us", type=int, default=0, help="skip the giga utterances that contains 'j uː ɛ s' because of the tokenization issue")
|
55 |
+
parser.add_argument("--pseudo_epoch_size", type=int, default=37901, help="only use for Eden scheduler. 37901 is the epoch size in the default optim setting, this is probably too big")
|
56 |
+
parser.add_argument("--switch_order", type=int, default=0, help="this is only for hificodec, where we switch the order of 2 and 3nd codebook")
|
57 |
+
parser.add_argument("--phn_folder_name", type=str, default="phoneme", help="for libritts I also have arpa phns, in which case should be phonemes_arpa")
|
58 |
+
parser.add_argument("--encodec_folder_name", type=str, default="mimi_8cb", help="folder where encodec codes are stored")
|
59 |
+
parser.add_argument("--manifest_name", type=str, default="manifest_final", help="if using hificodec, it should be hificodec_menifest, if using encodec, it is the default")
|
60 |
+
parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0")
|
61 |
+
parser.add_argument("--max_num_tokens", type=int, default=18750, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu")
|
62 |
+
parser.add_argument("--val_max_num_tokens", type=int, default=6000, help="FOR validation, this basically is for music-gen because of high mem consumption. max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu")
|
63 |
+
parser.add_argument("--num_buckets", type=int, default=10)
|
64 |
+
parser.add_argument("--dynamic_batching", type=int, default=1)
|
65 |
+
parser.add_argument("--audio_max_length", type=float, default=120, help="in second, crop the audio is length is longer than this")
|
66 |
+
parser.add_argument("--audio_min_length", type=float, default=2, help="in second, drop the audio if length is shorter than this")
|
67 |
+
parser.add_argument("--text_max_length", type=int, default=1000, help='if too long, we crop')
|
68 |
+
parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop")
|
69 |
+
parser.add_argument("--encodec_sr", type=float, default=50, help="for 24kHz mimi model, it produces 12.5 codes for 1 sec of audio")
|
70 |
+
parser.add_argument('--mask_len_min', type=int, default=20, help='Minimum mask length')
|
71 |
+
parser.add_argument('--mask_len_max', type=int, default=400, help='Maximum mask length')
|
72 |
+
parser.add_argument('--extra_mask_len_min', type=int, default=2, help='Minimum extra mask length')
|
73 |
+
parser.add_argument('--extra_mask_len_max', type=int, default=20, help='Maximum extra mask length')
|
74 |
+
parser.add_argument('--final_audio_token_len', type=int, default=772, help="this is only for stage 1 training, since we add eog, start_of_continue, and a random amount of extra mask, --audio_max_length won't be the final max length, the self.args.final_audio_token_len = self.args.audio_max_length*self.args.encodec_sr+self.args.extra_mask_len_max+2 ")
|
75 |
+
|
76 |
+
# model
|
77 |
+
parser.add_argument("--ttsonly", default=0, type=int, help="if 1, only train tts model, no CM3")
|
78 |
+
parser.add_argument("--load_existing_text_embedding", type=int, default=0, help="if 1, when load model and the text vocab doesn't match, will load the existing weights while the new weights will be initialized randomly")
|
79 |
+
parser.add_argument("--fly", type=int, default=0, help="if 1, encode chunked audio on the fly")
|
80 |
+
parser.add_argument("--encodec_ckpt", type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
|
81 |
+
parser.add_argument("--downsample_rate", type=int, default=320, help="the downsample rate for the encodec model, 16000/320 = 50Hz")
|
82 |
+
parser.add_argument("--segtts_mask", type=int, default=0, help="if 1, use segtts_mask model, where we have a prefix and segment utterance into two and shifted separately for modeling, and use make use of mask:0, by insert two mask:0 in the middle of the two segments")
|
83 |
+
parser.add_argument("--segtts", type=int, default=0, help="if 1, use segtts model, where we have a prefix and segment utterance into two and shifted separately for modeling")
|
84 |
+
parser.add_argument("--edge", type=int, default=0, help="if 1, use edge prediction for the first codebook")
|
85 |
+
parser.add_argument("--duration_loss_weight", type=float, default=1.0, help="weight on the duration loss")
|
86 |
+
parser.add_argument("--drop_long", type=int, default=1, help="if this is true, will drop example whose encodec sequence or phone sequence is too long, rather than cropping as we did before, to avoid hellucination")
|
87 |
+
parser.add_argument("--eos", type=int, default=2051, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4")
|
88 |
+
parser.add_argument("--reduced_eog", type=int, default=1, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts")
|
89 |
+
|
90 |
+
parser.add_argument("--valle_orig", type=int, default=0, help="the original valle model, trained for TTS")
|
91 |
+
parser.add_argument("--valle_max_prompt_len", type=float, default=6, help='in sec.')
|
92 |
+
# randomly choose a portion as tts examples during training
|
93 |
+
parser.add_argument("--tts_portion", type=float, default=0, help="randomly choose a portion of the training examples as tts examples, where no mask and rearrangement is used")
|
94 |
+
|
95 |
+
# put special tokens first to handle different vocab_size
|
96 |
+
parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens")
|
97 |
+
parser.add_argument("--n_special", type=int, default=4, help="empty, eog, pad, eos")
|
98 |
+
|
99 |
+
# weight codebook differently
|
100 |
+
parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']")
|
101 |
+
|
102 |
+
# args for MusicGen
|
103 |
+
parser.add_argument("--mask_span_weight", default=1.0, type=float, help="the weight on the tokens in masked span")
|
104 |
+
parser.add_argument("--unmask_span_weight", default=1.0, type=float, help="the weight on unmasked span")
|
105 |
+
parser.add_argument("--start_end_weight", default=None, type=str, help="weight the start x tokens and end x tokens differently, e.g. (10,2.0), means x == 10, weight==2.0")
|
106 |
+
# for now not consider the two weights above, only consider eog_weight, which is defined below somewhere, as the above two are not super principled
|
107 |
+
|
108 |
+
parser.add_argument("--musicgen", type=int, default=0, help="whether or not use this model, will also have an impact on the output shape of the dataset")
|
109 |
+
parser.add_argument("--enc_dec", default=0, type=int, help="use enc-dec architecture, text is from the enc, only for musicgen")
|
110 |
+
parser.add_argument("--dec", default=0, type=int, help="use dec only architecture, text is from the enc, only for musicgen. Exclusive with --enc_dec")
|
111 |
+
parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook")
|
112 |
+
# args for the optimizer and scheduler from Feiteng
|
113 |
+
# original setup for the 3 params are 5000 4 and 1000
|
114 |
+
# but that's because set_epoch is run on num_gradient_accumulation_step*step (with 4 being the accumulation step)
|
115 |
+
# so I scaled down them a little bit
|
116 |
+
# will try scaling them back if this doesn't work
|
117 |
+
parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler")
|
118 |
+
parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer')
|
119 |
+
parser.add_argument("--reduce_lr_start_epoch", type=int, default=4)
|
120 |
+
parser.add_argument("--clipping_update_period", type=int, default=600)
|
121 |
+
|
122 |
+
|
123 |
+
# below are args for valle
|
124 |
+
# below are args for valle
|
125 |
+
parser.add_argument("--valle", type=int, default=0, help="if 1, use valle model (cm3)")
|
126 |
+
parser.add_argument("--decoder_dim", type=int, default=1024)
|
127 |
+
parser.add_argument("--norm_first", action="store_true", default=True)
|
128 |
+
parser.add_argument("--add_prenet", action="store_true", default=False)
|
129 |
+
parser.add_argument("--prefix_mode", type=int, default=5, help="this is for NAR, we only do 5, which is CM3")
|
130 |
+
parser.add_argument("--share_embedding", action="store_true", default=False)
|
131 |
+
parser.add_argument("--nar_scale_factor", type=float, default=1.0)
|
132 |
+
parser.add_argument("--prepend_bos", action="store_true", default=False)
|
133 |
+
parser.add_argument("--sync_nar", type=int, default=0, help="whether to choose the same NAR model to run for training_stage==2 across different process (this is only for DDP)")
|
134 |
+
# above are args for valle
|
135 |
+
# above are args for valle
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
# add parallel_pattern
|
140 |
+
parser.add_argument("--parallel_pattern", type=int, default=0, help="if 1, use parallel pattern, we also use LFSC codec")
|
141 |
+
parser.add_argument("--full_prediction", type=int, default=0, help='this is for ve1, if 1, use full autoregressive mask, and calculate loss over all tokens, except for mask_tokens')
|
142 |
+
parser.add_argument("--multicm3", type=int, default=0, help='cm3 model but allows multiple mask spans')
|
143 |
+
parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask a utterance for more than this portion")
|
144 |
+
parser.add_argument("--max_n_spans", type=int, default=8, help='maximal number of spans, only use when using multicm3, this is used to decide number of mask_embedding, and max clamp value if use Poisson distribution, if use uniform distribution to sample number of spans if will be uniform(1,max_n_spans)')
|
145 |
+
parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has it's benefit, as it make sure that mask:0 always appear the first")
|
146 |
+
parser.add_argument("--mask_sample_dist", type=str, default="uniform", help="uniform or poissonx, e.g. poisson1, meaning the parameter lambda is 1, it will most likely sample 1 masks")
|
147 |
+
parser.add_argument("--min_gap", type=int, default=10, help="after sampled starts, delete later one if it closer to the former start than the min_gap")
|
148 |
+
|
149 |
+
parser.add_argument('--cm3', type=int, default=0, help="use cm3 style for ve1, the input from dataloader is going to be just raw data, all masking and rearrangement will happen whin the model")
|
150 |
+
|
151 |
+
parser.add_argument('--sep_special_token', type=int, default=0, help="remove text/audio pad token, set audio_mask_token and start of continue to be separately learned embeddings. Therefore, for ve1 self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size + 2, for ve7, self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size")
|
152 |
+
parser.add_argument('--one_causal', type=int, default=0, help="whether model VE_one generation as autoregressive gen or non-autoregressive gen")
|
153 |
+
parser.add_argument('--n_codebooks', type=int, default=8)
|
154 |
+
parser.add_argument('--weight_sharing', type=int, default=0, help="sharing weights between VE_seven predict layer and embedding layer")
|
155 |
+
parser.add_argument('--text_vocab_size', type=int, default=86, help='Size of text vocabulary')
|
156 |
+
parser.add_argument('--text_pad_token', type=int, default=86, help='padding of the text tokens, not attended')
|
157 |
+
# parser.add_argument('--audio_vocab_size', type=int, default=1024, help='Size of audio vocabulary')
|
158 |
+
parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary, can be specified as '[128,512,1024,2048]'")
|
159 |
+
parser.add_argument('--audio_mask_token', type=int, default=1024, help='Audio mask token, this the the extra mask used in the masked region for AR, for NAR, the entire masked region will be filled with it')
|
160 |
+
parser.add_argument('--bog', type=int, default=1025, help='Begin of generation token')
|
161 |
+
parser.add_argument('--eog', type=int, default=2049, help='End of generation token')
|
162 |
+
parser.add_argument('--start_of_continue', type=int, default=1027, help='this token follows the masked region, proceeds the first unmasked token, to indicate that gt tokens starts')
|
163 |
+
parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended')
|
164 |
+
parser.add_argument('--d_model', type=int, default=1024, help='Model dimension')
|
165 |
+
parser.add_argument('--audio_embedding_dim', type=int, default=128, help='dimension for encodec continues embedding (before being quantized)')
|
166 |
+
parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding')
|
167 |
+
parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding')
|
168 |
+
parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding')
|
169 |
+
parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding')
|
170 |
+
parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer')
|
171 |
+
parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads')
|
172 |
+
parser.add_argument('--num_encoder_layers', type=int, default=12, help='Number of encoder layers')
|
173 |
+
parser.add_argument('--num_decoder_layers', type=int, default=12, help='Number of decoder layers')
|
174 |
+
parser.add_argument('--eog_weight', type=float, default=1.0, help='Weight for End of generation token')
|
175 |
+
parser.add_argument('--stage_one_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage one. On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks')
|
176 |
+
parser.add_argument('--stage_two_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage two。 On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks')
|
177 |
+
parser.add_argument('--stage_two_load_ve_one_embedding', type=str, default=None, help='Path to load VoiceEditor_one audio embedding for stage two')
|
178 |
+
parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume')
|
179 |
+
parser.add_argument('--load_model_from_ve1', type=str, default=None, help='Path to load ve1 model weights from, this will be effective last, designed for loading the encoder weights of the VE7 from a pretrained VE1')
|
180 |
+
|
181 |
+
|
182 |
+
## below are args for the new long model
|
183 |
+
parser.add_argument("--target_time_stretch_prob", type=float, default=0, help="the probability of time stretching the target audio")
|
184 |
+
parser.add_argument("--target_time_stretch_bound", type=float, default=0.1, help="the bound of the time stretching target audio, e.g. 0.1 means the audio will be stretched by 0.9 to 1.1")
|
185 |
+
parser.add_argument("--time_stretch_prob", type=float, default=0, help="the probability of time stretching the audio")
|
186 |
+
parser.add_argument("--time_stretch_bound", type=float, default=0.3, help="the bound of the time stretching, e.g. 0.3 means the audio will be stretched by 0.7 to 1.3")
|
187 |
+
parser.add_argument("--no_loss_on_prefix", type=int, default=0, help="if 1, will not calculate loss on the prefix acoustic tokens")
|
188 |
+
parser.add_argument("--x_sep_token", type=int, default=None, help="if not None, will use this token in between prompt text and target generation text")
|
189 |
+
parser.add_argument("--y_sep_token", type=int, default=None, help="if not None, will use this token in between prompt codec tokens and target codec tokens")
|
190 |
+
parser.add_argument("--neighbor_prompt_prob", type=float, default=0, help="the probability of using the prompt from the neighbor")
|
191 |
+
parser.add_argument("--neighbor_folder_name", type=str, default='neighbors',help="folder where the neighbors of the current audio files are stored, each row contains three tab separated entries: neighbor_fn, neighbor_temporal_distance, neighbor_duration")
|
192 |
+
parser.add_argument("--alignment_folder_name", type=str, default='alignment', help="folder where the forced alignment of the current audio files are stored, in csv format, each row contains five comma separated entries: begin, end, label, type, speaker, the first row is header")
|
193 |
+
parser.add_argument("--ipa_alignment_folder_name", type=str, default='ipa_alignment', help="folder where the forced alignment of the current audio files are stored, in txt format, each row contains three tab separated entries: begin, end, ipa phn sequence, generated using data/ll60k_preprocessing/step7_ipa_alignment.py")
|
194 |
+
parser.add_argument("--max_prompt_len", type=float, default=30, help="in sec., maximal prompt length selected from some neighboring file")
|
195 |
+
parser.add_argument("--min_prompt_len", type=float, default=0.5, help="in sec., minimal prompt length selected from some neighboring file")
|
196 |
+
parser.add_argument("--neighbor_selection_method", type=str, default="maxdist_60", help="maxdist_60 means uniformly select a neighbor that's within 60 sec of the current audio file")
|
197 |
+
parser.add_argument("--num_trial", type=int, default=5, help="number of tries to select a neighbor")
|
198 |
+
parser.add_argument("--prompt_start_from_begining_prob", type=float, default=0.5, help="the probability of starting the prompt from the beginning of the neighbor")
|
199 |
+
parser.add_argument("--min_alignment_len", type=int, default=5, help="in number of words")
|
200 |
+
parser.add_argument("--audio_folder_name", type=str, default='audio', help="folder where the audio files are stored")
|
201 |
+
|
202 |
+
# rope parameters
|
203 |
+
parser.add_argument("--decoder_regular_rope", type=int, default=0, help="if 1, will use regular rope for the decoder (note that we always use regular rope for encoder). ")
|
204 |
+
parser.add_argument("--progress_no_multiple", type=int, default=0, help="if 1, will not multiple the percentage progress by the length of the key, see apply_rotary_pos_emb in models/modules/activation.py, this applies to both rope and sinusoidal positional encoding. Note that progress scale is still applied, i.e. when we only apply progress scale, but not multiple, the scaling factor is constant for every sample, rather than sample dependent")
|
205 |
+
parser.add_argument("--add_eos_to_text", type=int, default=0, help="if not 0, use this number as eos and add to the end of text token, usually use the second to last token in the vocab size")
|
206 |
+
parser.add_argument("--add_bos_to_text", type=int, default=0, help="if not 0, use this number as bos and add to the begining of text token, usually use the third to last token in the vocab size")
|
207 |
+
parser.add_argument("--use_sinusoidal", type=int, default=0, help="if 1, will use sinusoidal positional encoding, otherwise use rope. BUT if rope_base is None, will use sinusoidal")
|
208 |
+
parser.add_argument("--sinusoidal_base", type=int, default=1e4, help="the base of the exponential function, default is 1e4")
|
209 |
+
parser.add_argument("--use_sinusoidal_progress", type=int, default=0, help="if 1, will use sinusoidal positional encoding for progress, otherwise use rope")
|
210 |
+
parser.add_argument("--rope_base", type=int, default=None, help="the base of the exponential function, default is 1e4, if None, will not use rope")
|
211 |
+
parser.add_argument("--multiple_key_length", type=int, default=0, help="if 1, during progress calculation, will multiple the precentage progress by the length of the key, otherwise multiple with length of query. see models/rope_playground.ipynb")
|
212 |
+
parser.add_argument("--progress_scale", type=float, default=1.0, help="scale the progress, the smaller the value, the bigger the diagonal in attention score, see models/rope_playground.ipynb")
|
213 |
+
|
214 |
+
# attention alignment loss
|
215 |
+
parser.add_argument("--attention_alignment_loss", type=float, default=0.0, help="the weight on the attention alignment loss, if 0, will not calculate the loss")
|
216 |
+
parser.add_argument("--alignment_loss_layer", type=str, default="['0-1', '2', '3']", help='the layers to calculate the alignment loss, e.g. ["0-1", "2", "3"]')
|
217 |
+
parser.add_argument("--alignment_loss_head", type=str, default="['0-1', '2', '3']", help='the attention heads to calculate the alignment loss, e.g. ["0-1", "2", "3"]')
|
218 |
+
parser.add_argument("--alignment_blank_logit", type=float, default=-1.0, help="the logit for the blank token added to the attention weights")
|
219 |
+
|
220 |
+
# inference parameters
|
221 |
+
parser.add_argument("--metrics", type=str, default="['spk_sim','wer','mcd','pitch','energy','pesq','utmos']")
|
222 |
+
parser.add_argument("--res_jsonl_root", type=str, default="/home/pyp/BoostedVoiceEditor/res")
|
223 |
+
parser.add_argument("--res_name", type=str, default="2jan25.jsonl")
|
224 |
+
parser.add_argument("--inference_seed", type=int, default=1)
|
225 |
+
parser.add_argument("--codec_audio_sr", type=int, default=16000)
|
226 |
+
parser.add_argument("--codec_sr", type=float, default=50)
|
227 |
+
parser.add_argument("--top_k", type=int, default=0)
|
228 |
+
parser.add_argument("--top_p", type=float, default=0.9)
|
229 |
+
parser.add_argument("--temperature", type=float, default=1)
|
230 |
+
parser.add_argument("--silence_tokens", type=list, default=[])
|
231 |
+
parser.add_argument("--kvcache", type=int, default=0)
|
232 |
+
parser.add_argument("--stop_repetition", type=int, default=3)
|
233 |
+
parser.add_argument("--sample_batch_size", type=int, default=1)
|
234 |
+
parser.add_argument("--inference_manifest_fns", type=str, default="['/home/pyp/BoostedVoiceEditor/manifests/debug.jsonl']")
|
235 |
+
parser.add_argument("--use_gt_duration", type=int, default=1)
|
236 |
+
parser.add_argument("--save_root", type=str, default="/data/scratch/pyp/exp_pyp/BoostedVoiceEditor/gens")
|
237 |
+
parser.add_argument("--encodec_signature", type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
|
238 |
+
parser.add_argument("--extra_cutoff", type=float, default=5, help="in rare cases where the model doesn't follow specified target duration (only happened in extrapolation cases), we will terminate generation once the extra duration exceeds this value")
|
239 |
+
parser.add_argument("--duration_margin", type=float, default=0.04, help="used along with extra_cutoff, when extra_cutoff is used (i.e. model doesn't follow specified target_duration), we terminate the generate, and cut the results to target_duration + duration_margin")
|
240 |
+
# add repeat_prompt and asr_model_name
|
241 |
+
parser.add_argument("--repeat_prompt", type=int_or_str, default=0, help="if 1, will repeat the prompt for each segment")
|
242 |
+
parser.add_argument("--asr_model_name", type=str, default="w2v2", help="the name of the asr model, if not None, will use the asr model to generate the prompt")
|
243 |
+
|
244 |
+
# depth transformer parameters
|
245 |
+
parser.add_argument("--depth_dec_num_layers", type=int, default=0)
|
246 |
+
parser.add_argument("--depth_dec_d_model", type=int, default=768)
|
247 |
+
parser.add_argument("--depth_dec_nhead", type=int, default=12)
|
248 |
+
parser.add_argument("--moshi_depth", type=int, default=0, help="if 1, will use the same parameterization as moshi, i.e. temporal trm output will gets added to every transformed token embedding")
|
249 |
+
|
250 |
+
parser.add_argument("--validation_sample_cap", type=int, default=None, help="cap the validation data to this number")
|
251 |
+
parser.add_argument("--no_libri_in_training", type=int, default=None, help="if 1, will not use librilight in training, only use in validation")
|
252 |
+
parser.add_argument("--uniform_weight_start_step", type=int, default=1e50, help="set all codebook weight to be uniform starting from this step")
|
253 |
+
|
254 |
+
return parser
|
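Note that several of these arguments (--metrics, --alignment_loss_layer, --inference_manifest_fns) take Python-literal strings that downstream code evaluates into lists, the same convention combined_dataset.py uses for dataset_dir and manifest_name. A minimal sketch of that convention (the argument subset below is illustrative; the repo uses eval(), and ast.literal_eval is shown here as the safer equivalent):

import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--metrics", type=str, default="['spk_sim','wer','mcd']")
args = parser.parse_args([])

# Parse the Python-literal string into an actual list on the consumer side.
metrics = ast.literal_eval(args.metrics)
assert metrics == ['spk_sim', 'wer', 'mcd']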
copy_codebase.py
ADDED
@@ -0,0 +1,56 @@
+
+import os
+import shutil
+import fnmatch
+
+def parse_gitignore(gitignore_path):
+    """Parse a .gitignore file and return a list of patterns."""
+    patterns = []
+    with open(gitignore_path, "r") as f:
+        for line in f:
+            # Ignore comments and blank lines
+            line = line.strip()
+            if not line or line.startswith("#"):
+                continue
+            # Handle wildcards and directory separators
+            patterns.append(line)
+    return patterns
+
+def file_matches_patterns(file_path, patterns):
+    """Check if a file matches any of the patterns in .gitignore."""
+    for pattern in patterns:
+        if fnmatch.fnmatch(file_path, pattern):
+            return True
+    return False
+
+def copy_codebase(src, dst, max_size_mb=5, gitignore_path=None):
+    """Copy files from src to dst, skipping files larger than max_size_mb and files matching .gitignore patterns."""
+    if gitignore_path and os.path.exists(gitignore_path):
+        patterns = parse_gitignore(gitignore_path)
+    else:
+        patterns = []
+    print("patterns to ignore: ", patterns)
+    os.makedirs(dst, exist_ok=True)
+    for root, dirs, files in os.walk(src):
+        for file in files:
+            file_path = os.path.join(root, file)
+            relative_path = os.path.relpath(file_path, src)
+            dst_path = os.path.join(dst, relative_path)
+            # ignore .git because of permission issues
+            if "/.git/" in file_path:
+                continue
+
+            # Check .gitignore patterns
+            if file_matches_patterns(file_path, patterns):
+                # print(f"Skipping {file_path} because it matches a pattern in .gitignore")
+                continue
+
+            # Check file size
+            if os.path.getsize(file_path) > max_size_mb * 1024 * 1024:
+                print(f"Skipping {file_path} because it's larger than {max_size_mb}MB")
+                continue
+
+            # Make sure the destination directory exists
+            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
+            shutil.copy(file_path, dst_path)
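A hedged usage sketch (the src/dst paths are placeholders, not values from this repo): snapshotting the current repository into an experiment directory while honoring its .gitignore.

from copy_codebase import copy_codebase

copy_codebase(
    src=".",                       # repo root (placeholder)
    dst="/tmp/exp/codebase_copy",  # snapshot destination (placeholder)
    max_size_mb=5,                 # skip large artifacts such as checkpoints and audio
    gitignore_path="./.gitignore",
)

One caveat on the design: fnmatch matches the full path against each raw pattern, which only approximates real .gitignore semantics (for example, directory patterns like "build/" will not match), so treat the ignore list as best-effort.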
data/__init__.py
ADDED
File without changes
data/combined_dataset.py
ADDED
@@ -0,0 +1,466 @@
+
+import os
+import ffmpeg
+import torch
+import random
+import copy
+import logging
+import torch.distributed as dist
+import shutil
+import csv
+import torchaudio
+import glob
+import numpy as np
+from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer
+
+def find_files(root_dir, endswith=".wav"):
+    files = []
+    # os.walk generates the file names in a directory tree
+    for dirpath, dirnames, filenames in os.walk(root_dir):
+        for filename in filenames:
+            # os.path.splitext splits the file name into a base and extension
+            # base, ext = os.path.splitext(filename)
+            if filename.lower().endswith(endswith):
+                # os.path.join combines one or more path names into a single path
+                full_path = os.path.join(dirpath, filename)
+                files.append(full_path)
+    return files
+
+class dataset(torch.utils.data.Dataset):
+    def __init__(self, args, split):
+        super().__init__()
+        self.args = args
+        self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0)
+        self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1)
+        self.split = split
+
+        assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}"
+
+        if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir:
+            self.dataset_dir = f"['{self.args.dataset_dir}']"
+        else:
+            self.dataset_dir = copy.deepcopy(self.args.dataset_dir)
+        self.dataset_dir = eval(self.dataset_dir)
+        data = []
+        if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name:
+            self.args.manifest_name = f"['{self.args.manifest_name}']"
+        else:
+            self.args.manifest_name = copy.deepcopy(self.args.manifest_name)
+        self.manifest_name = eval(self.args.manifest_name)
+        if len(self.manifest_name) != len(self.dataset_dir):
+            assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}"
+            self.manifest_name = self.manifest_name * len(self.dataset_dir)
+        for i_data, dataset_dir in enumerate(self.dataset_dir):
+            if getattr(self.args, "no_libri_in_training", None) is not None and ("librilight" in dataset_dir) and self.split == "train":
+                if not dist.is_initialized() or dist.get_rank() == 0:
+                    logging.info(f"skipping librilight in training split")
+                continue
+            n_datapoints = 0
+            manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt")
+            if not os.path.isfile(manifest_fn):
+                all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt"))
+                if len(all_manifest_fn) == 0:
+                    logging.info(f"no manifest file found for {split} split in {dataset_dir}")
+                    continue
+                if self.args.debug:
+                    logging.info(f"debugging mode, only using the first found manifest file: {all_manifest_fn[0]}")
+                    all_manifest_fn = all_manifest_fn[:1]
+                else:
+                    if dist.is_initialized() and dist.get_rank() == 0:
+                        logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}")
+                for cur_manifest_fn in all_manifest_fn:
+                    with open(cur_manifest_fn, "r") as rf:
+                        tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()]  # i_data is the index of the dataset
+                        n_datapoints += len(tmp)
+                        data += tmp
+            else:
+                with open(manifest_fn, "r") as rf:
+                    tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()]
+                    data += tmp
+                    n_datapoints += len(tmp)
+            if dist.is_initialized() and dist.get_rank() == 0:
+                logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}")
+        assert len(data) > 0, f"no data found for {split} split"
+        lengths_list = [int(item[1]) for item in data]  # use index 1 because there might be more than 2 columns (for gigaspeech we have 3 columns: path, duration, selfsim)
+        self.data = []
+        self.lengths_list = []
+        total_duration = 0
+        for d, l in zip(data, lengths_list):
+            if l >= self.args.encodec_sr*self.args.audio_min_length:
+                if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
+                    continue
+                self.data.append(d)
+                self.lengths_list.append(l)
+                total_duration += l / self.args.encodec_sr / 3600
+        # logging.info(f"for now cut the dataset to only have 500 examples for debugging")
+        # self.data = self.data[:1000]
+        # self.lengths_list = self.lengths_list[:1000]
+        if dist.is_initialized() and dist.get_rank() == 0:
+            logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}")
+            logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours")
+        # phoneme vocabulary
+        phn_set = set()
+        for dataset_dir in self.dataset_dir:
+            vocab_fn = os.path.join(dataset_dir, "vocab.txt")
+            with open(vocab_fn, "r") as f:
+                temp = [l.strip().split("\t") for l in f.readlines() if len(l) != 0]
+                phn_set.update([item[-1] for item in temp])
+        self.phn2num = {item:i for i, item in enumerate(phn_set)}
+        assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than the number of phns in the vocab to handle OOD phns, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}"
+
+        if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0:
+            userdir = os.path.expanduser("~")
+            encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"))
+            self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True)
+            assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.codec_audio_sr: {self.args.codec_audio_sr}"
+            if dist.is_initialized() and dist.get_rank() == 0:
+                logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}")
+
+    def __len__(self):
+        return len(self.lengths_list)
+
+    def _load_phn_enc(self, index):
+        item = self.data[index]
+        dataset_dir = self.dataset_dir[item[-1]]
+        pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt")
+        ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt")
+        # with some probability, we load the audio and time-stretch it; note that the result must not exceed self.args.audio_max_length
+        if "/librilight" in dataset_dir:
+            audio_ext = ".flac"
+        elif "/emilia" in dataset_dir:
+            audio_ext = ".mp3"
+        else:
+            raise NotImplementedError(f"dataset_dir: {dataset_dir}")
+
+        audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext)
+        speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1
+        length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length  # NOTE the maximal duration after time stretching is orig/(1-bound), not orig*(1+bound), hence the division by the sampled speed factor
+        if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok:
+            try:
+                with open(pf, "r") as p:
+                    phns = [l.strip() for l in p.readlines()]
+                    assert len(phns) == 1, phns
+                    all_phns = phns[0].split(" ")
+                    x = [self.phn2num[item] for item in all_phns if item in self.phn2num]
+            except:
+                logging.info(f"loading failed for {pf}, maybe the file doesn't exist or is corrupted")
+                return [], [[]], dataset_dir, audio_ext
+            # time stretch
+            try:
+                process = (
+                    ffmpeg.input(audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr)
+                    .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
+                    .run_async(pipe_stdout=True, pipe_stderr=True)
+                )
+                # Read the processed audio from ffmpeg stdout
+                output, _ = process.communicate()
+
+                # Convert the output to a numpy array
+                output_np = np.frombuffer(output, dtype=np.float32).copy()
+
+                # Reshape the numpy array back to the expected shape (1, samples for mono)
+                waveform = torch.from_numpy(output_np)
+                waveform = waveform.unsqueeze(0).unsqueeze(0)
+                assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
+                with torch.no_grad():
+                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
+                assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
+                encos = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
+                if self.args.special_first:
+                    raise NotImplementedError
+                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
+                else:
+                    y = [[int(n) for n in l] for l in encos]
+                return x, y, dataset_dir, audio_ext
+            except Exception as e:
+                logging.info(f"failed with time stretch and codec encode for {audio_fn}")
+                logging.info(f"error: {e}")
+
+        try:
+            with open(pf, "r") as p, open(ef, "r") as e:
+                phns = [l.strip() for l in p.readlines()]
+                assert len(phns) == 1, phns
+                all_phns = phns[0].split(" ")
+                x = [self.phn2num[item] for item in all_phns if item in self.phn2num]  # we assume OOD phns will not happen, because the phn vocab is small
+                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
+
+                assert len(encos) == self.args.n_codebooks, ef
+
+                if self.args.special_first:
+                    raise NotImplementedError
+                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
+                else:
+                    y = [[int(n) for n in l] for l in encos]
+        except:
+            logging.info(f"loading failed for {pf} and {ef}, maybe the files don't exist or are corrupted")
+            return [], [[]], dataset_dir, audio_ext
+
+        return x, y, dataset_dir, audio_ext
+
+    # this uses the output of step7_ipa_alignment.py
+    def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext):
+        neighbor = random.choice(neighbors)
+        neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0])
+        if not os.path.isfile(neighbor_enc_fn):
+            return None, None
+        neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext))
+        if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path):
+            logging.info(f"audio file not found: {neighbor_audio_path}")
+            return None, None
+        if random.random() < getattr(self.args, "time_stretch_prob", 0):
+            time_stretch_flag = True
+            speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1
+            duration_factor = 1 / speed_factor
+        else:
+            time_stretch_flag = False
+            duration_factor = 1
+
+        ####################### TODO for now always use the entire neighbor for emilia
+        # if it's gigaspeech or emilia, we did not run MFA forced alignment (and therefore have no ipa alignment), so we just use the entire neighbor as the prompt
+        if "/emilia" in dataset_dir:
+            # get neighbor duration
+            neighbor_dur = float(neighbor[2])
+            if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len:
+                return None, None
+            try:
+                neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0])
+                with open(neighbor_pf, "r") as p:
+                    phns = [l.strip() for l in p.readlines()]
+                    assert len(phns) == 1, phns
+                    all_phns = phns[0].split(" ")
+                    phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num]
+            except:
+                logging.info(f"loading failed for {neighbor_pf}, maybe the file doesn't exist")
+                return None, None
+            # if we do not stretch the audio, read the stored encodec codes directly
+            if not time_stretch_flag:
+                with open(neighbor_enc_fn, "r") as f:
+                    neighbor_enc = [l.strip().split() for l in f.readlines()]
+                    if len(neighbor_enc) != self.args.n_codebooks:
+                        return None, None
+                    else:
+                        if self.args.special_first:
+                            raise NotImplementedError
+                            # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc]
+                        else:
+                            neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
+                        return phn_token, neighbor_enc
+            else:  # stretch the audio with ffmpeg-python
+                process = (
+                    ffmpeg.input(neighbor_audio_path, ss=0, t=neighbor_dur)
+                    .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
+                    .run_async(pipe_stdout=True, pipe_stderr=True)
+                )
+                # Read the processed audio from ffmpeg stdout
+                output, _ = process.communicate()
+
+                # Convert the output to a numpy array
+                output_np = np.frombuffer(output, dtype=np.float32).copy()
+
+                # Reshape the numpy array back to the expected shape (1, samples for mono)
+                waveform = torch.from_numpy(output_np)
+                waveform = waveform.unsqueeze(0).unsqueeze(0)
+                assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
+                with torch.no_grad():
+                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
+                assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
+                neighbor_enc = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
+                return phn_token, neighbor_enc
+        ####################### TODO for now always use the entire neighbor for emilia
+        ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0])
+        if not os.path.isfile(ipa_alignment_fn):
+            # print(f"file not found: {ipa_alignment_fn}", flush=True)
+            return None, None
+        with open(ipa_alignment_fn, "r") as f:
+            alignments = [l.strip().split("\t") for l in f.readlines()]
+            alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3]
+        alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len]
+        if len(alignments) == 0:
+            # print(f"no valid alignment found for {ipa_alignment_fn}")
+            return None, None
+        idx = random.choice(range(len(alignments)))
+        while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length:
+            idx -= 1
+            if idx < 0:
+                # print(f"too long combined with y_len {ipa_alignment_fn=}, and {y_len=}")
+                return None, None
+        if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len:
+            return None, None
+
+        start_time, end_time = alignments[idx][:2]
+        phn = alignments[idx][2].split(" ")
+        phn_token = [self.phn2num[item] for item in phn if item in self.phn2num]
+        if len(phn_token) == 0:
+            return None, None
+
+        if time_stretch_flag:
+            duration = end_time - start_time
+            process = (
+                ffmpeg.input(neighbor_audio_path, ss=start_time, t=duration)
+                .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
+                .run_async(pipe_stdout=True, pipe_stderr=True)
+            )
+            # Read the processed audio from ffmpeg stdout
+            output, _ = process.communicate()
+
+            # Convert the output to a numpy array
+            output_np = np.frombuffer(output, dtype=np.float32).copy()
+
+            # Reshape the numpy array back to the expected shape (1, samples for mono)
+            waveform = torch.from_numpy(output_np)
+            waveform = waveform.unsqueeze(0).unsqueeze(0)
+            assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
+            try:
+                with torch.no_grad():
+                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
+            except:
+                logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds")
+                return None, None
+            assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
+            neighbor_enc = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
+            return phn_token, neighbor_enc
+        else:
+            # get encodec codes from storage
+            with open(neighbor_enc_fn, "r") as f:
+                neighbor_enc = [l.strip().split() for l in f.readlines()]
+                if len(neighbor_enc) != self.args.n_codebooks:
+                    # print(f"wrong number of codebooks for {neighbor_enc_fn}")
+                    return None, None
+                else:
+                    # trim the encodec codes to the segment
+                    start_enc_frame = int(start_time * self.args.encodec_sr)
+                    end_enc_frame = int(end_time * self.args.encodec_sr)
+                    neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc]
+                    if len(neighbor_enc[0]) == 0:
+                        # print(f"no valid encodec codes found for {neighbor_enc_fn}")
+                        return None, None
+                    if self.args.special_first:
+                        raise NotImplementedError
+                    else:
+                        neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
+                    return phn_token, neighbor_enc
+
+    def __getitem__(self, index):
+        x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
+        x_len, y_len = len(x), len(y[0])
+        extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0}
+        if x_len == 0 or y_len == 0:
+            ret = {
+                "x": None,
+                "x_len": None,
+                "y": None,
+                "y_len": None,
+            }
+            ret.update(extra_ret)
+            return ret
+        while y_len < self.args.encodec_sr*self.args.audio_min_length:
+            assert not self.args.dynamic_batching
+            index = random.choice(range(len(self)))  # regenerate an index
+            x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
+            x_len, y_len = len(x), len(y[0])
+
+        # if using a neighbor prompt
+        x_neighbor, y_neighbor = None, None
+        use_neighbor_prob = random.random()
+        neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt")
+        if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn):  # it might not exist simply because we didn't find a neighbor for this file (other than itself, which is common for emilia)
+            with open(neighbor_fn, "r") as f:
+                neighbors = [l.strip().split("\t") for l in f.readlines()]
+            # select neighbors
+            if "maxdist" in self.args.neighbor_selection_method:
+                maxdist = int(self.args.neighbor_selection_method.split("_")[-1])
+                # only keep neighbors with distance within maxdist
+                neighbors = [n for n in neighbors if float(n[1]) <= maxdist]
+            else:
+                raise NotImplementedError
+            x_neighbor, y_neighbor = None, None
+            if len(neighbors) > 0:
+                x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
+                i_trial = 0
+                while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors):
+                    x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
+                    i_trial += 1
+
+        if x_neighbor is not None:
+            if self.args.x_sep_token is not None:
+                x = x_neighbor + [self.args.x_sep_token] + x
+            else:
+                x = x_neighbor + x
+            if self.args.y_sep_token is not None:
+                y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))]
+            else:
+                y = [y_neighbor[i] + y[i] for i in range(len(y))]
+            extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1  # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in loss computation, it's fine to use the position of the token right before it
+            extra_ret['x_sep_token_position'] = len(x_neighbor) + 1
+            x_len, y_len = len(x), len(y[0])
+
+        # consider adding eos to the end of the text
+        if self.args.add_eos_to_text != 0:
+            x.append(self.args.add_eos_to_text)
+            x_len += 1
+        if getattr(self.args, "add_bos_to_text", 0) != 0:
+            x = [self.args.add_bos_to_text] + x
+            x_len += 1
+        ### padding and cropping ###
+        # adjust the length of the encodec codes: pad to max_len or randomly crop
+        orig_y_len = copy.copy(y_len)
+        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
+        if y_len > max_len + 10:  # give it some margin for rounding error
+            raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}")
+        else:
+            audio_start = 0
+            if not self.args.dynamic_batching:
+                pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
+                for i in range(len(y)):
+                    y[i] = y[i] + pad
+
+        if self.args.pad_x and x_len <= self.args.text_max_length:
+            pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
+            x = x + pad
+
+        ret = {
+            "x": torch.LongTensor(x),
+            "x_len": x_len,
+            "y": torch.LongTensor(y),
+            "y_len": y_len,
+        }
+        ret.update(extra_ret)
+
+        return ret
+
+    def collate(self, batch):
+        # make sure the keys in every batch are the same
+        for batch1, batch2 in zip(batch[:-1], batch[1:]):
+            assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different"
+        out = {key:[] for key in batch[0]}
+        for item in batch:
+            if item['x'] is None:  # deal with load failure
+                continue
+            for key, val in item.items():
+                out[key].append(val)
+        res = {}
+        if self.args.pad_x:
+            res["x"] = torch.stack(out["x"], dim=0)
+        else:
+            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
+        res["x_lens"] = torch.LongTensor(out["x_len"])
+        if self.args.dynamic_batching:
+            res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']], padding_value=self.args.audio_pad_token)
+            res['y'] = res['y'].permute(1,2,0)  # T B K -> B K T
+        else:
+            res['y'] = torch.stack(out['y'], dim=0)
+        res["y_lens"] = torch.LongTensor(out["y_len"])
+        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
+        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
+        if "y_sep_token_position" in out:
+            res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"])
+        if "x_sep_token_position" in out:
+            res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"])
+        return res
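For context, a minimal sketch of wiring this dataset into a PyTorch DataLoader via its custom collate method. The args fields below are a partial, hypothetical subset inferred from the attributes the class reads; the project's actual training entry point (main.py / steps/trainer.py) builds the full config, and this sketch assumes the manifest and vocab files already exist on disk:

from types import SimpleNamespace
import torch
from data.combined_dataset import dataset

# Hypothetical args: only some of the fields dataset.__init__/collate read.
args = SimpleNamespace(
    dataset_dir="/path/to/emilia", manifest_name="manifest",  # placeholders
    encodec_sr=50, audio_min_length=1.0, audio_max_length=16.0, drop_long=1,
    text_vocab_size=1000, neighbor_prompt_prob=0.0, time_stretch_prob=0.0,
    target_time_stretch_prob=0.0, debug=0, pad_x=0, dynamic_batching=1,
    text_pad_token=999, audio_pad_token=2048, n_codebooks=4,
    add_eos_to_text=0, sep_special_token=False, x_sep_token=None, y_sep_token=None,
)

train_set = dataset(args, "train")
loader = torch.utils.data.DataLoader(
    train_set, batch_size=4, shuffle=True,
    collate_fn=train_set.collate,  # pads x/y and builds the padding masks
)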
data/emilia_preprocessing/delete_tar_files.sh
ADDED
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Define the root directory where the tar files are located
+root=${root:-/data/scratch/pyp/datasets/emilia/downloads} # Example: /data/scratch/pyp/datasets/emilia/downloads
+exist_log_file="file_log_debug.txt" # Log of files to delete
+delete_log="deleted_files.log"      # Log of successfully deleted files
+error_log="delete_errors.log"       # Log of errors (e.g., missing files)
+
+# Clear previous logs
+> "$delete_log"
+> "$error_log"
+
+echo "Starting deletion of tar files listed in $exist_log_file..."
+
+# Loop through each line in exist_log_file
+while IFS=',' read -r filename size local_sha256 original_filename url; do
+    # Trim leading/trailing whitespace
+    original_filename=$(echo "$original_filename" | xargs)
+
+    # Construct the full path to the tar file
+    tar_file="${root}/${original_filename}"
+
+    # Check if the tar file exists
+    if [ -f "$tar_file" ]; then
+        # Attempt to delete the file
+        if rm -f "$tar_file"; then
+            echo "✅ Deleted: $tar_file"
+            echo "$tar_file" >> "$delete_log"
+        else
+            echo "❌ Failed to delete: $tar_file"
+            echo "$tar_file" >> "$error_log"
+        fi
+    else
+        # Log missing files
+        echo "❌ File not found: $tar_file"
+        echo "$tar_file" >> "$error_log"
+    fi
+done < "$exist_log_file"
+
+echo "Deletion process completed."
+echo "Deleted files are logged in $delete_log."
+echo "Errors (if any) are logged in $error_log."
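From the read loop above, the script expects each line of file_log_debug.txt to be a comma-separated record of the form filename,size,local_sha256,original_filename,url. A hypothetical example line (every value below is made up for illustration; the actual log is produced by step2_log_tar_files.sh):

EN_B00001.tar,2147483648,3b7a9c0fdeadbeef,EN/EN_B00001.tar,https://example.com/EN_B00001.tar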
data/emilia_preprocessing/encodec.py
ADDED
@@ -0,0 +1,1554 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Compression models or wrapper around existing models.
|
7 |
+
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
from dataclasses import dataclass, field
|
12 |
+
import logging
|
13 |
+
import math
|
14 |
+
from pathlib import Path
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from torch import einsum
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import warnings
|
26 |
+
from einops import rearrange, repeat
|
27 |
+
import omegaconf
|
28 |
+
# import flashy
|
29 |
+
|
30 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
31 |
+
'time_group_norm'])
|
32 |
+
|
33 |
+
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
|
34 |
+
"""Convenience function to map an omegaconf configuration to a dictionary.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
cfg (omegaconf.DictConfig): Original configuration to map to dict.
|
38 |
+
Returns:
|
39 |
+
dict: Config as dictionary object.
|
40 |
+
"""
|
41 |
+
dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
|
42 |
+
assert isinstance(dct, dict)
|
43 |
+
return dct
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class QuantizedResult:
|
47 |
+
x: torch.Tensor
|
48 |
+
codes: torch.Tensor
|
49 |
+
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
50 |
+
penalty: tp.Optional[torch.Tensor] = None
|
51 |
+
metrics: dict = field(default_factory=dict)
|
52 |
+
|
53 |
+
class BaseQuantizer(nn.Module):
|
54 |
+
"""Base class for quantizers.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
|
58 |
+
"""
|
59 |
+
Given input tensor x, returns first the quantized (or approximately quantized)
|
60 |
+
representation along with quantized codes, bandwidth, and any penalty term for the loss.
|
61 |
+
Finally, this returns a dict of metrics to update logging etc.
|
62 |
+
Frame rate must be passed so that the bandwidth is properly computed.
|
63 |
+
"""
|
64 |
+
raise NotImplementedError()
|
65 |
+
|
66 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
67 |
+
"""Encode a given input tensor with the specified sample rate at the given bandwidth."""
|
68 |
+
raise NotImplementedError()
|
69 |
+
|
70 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
71 |
+
"""Decode the given codes to the quantized representation."""
|
72 |
+
raise NotImplementedError()
|
73 |
+
|
74 |
+
@property
|
75 |
+
def total_codebooks(self):
|
76 |
+
"""Total number of codebooks."""
|
77 |
+
raise NotImplementedError()
|
78 |
+
|
79 |
+
@property
|
80 |
+
def num_codebooks(self):
|
81 |
+
"""Number of active codebooks."""
|
82 |
+
raise NotImplementedError()
|
83 |
+
|
84 |
+
def set_num_codebooks(self, n: int):
|
85 |
+
"""Set the number of active codebooks."""
|
86 |
+
raise NotImplementedError()
|
87 |
+
|
88 |
+
class CompressionModel(ABC, nn.Module):
|
89 |
+
"""Base API for all compression model that aim at being used as audio tokenizers
|
90 |
+
with a language model.
|
91 |
+
"""
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def forward(self, x: torch.Tensor) -> QuantizedResult:
|
95 |
+
...
|
96 |
+
|
97 |
+
@abstractmethod
|
98 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
99 |
+
"""See `EncodecModel.encode`."""
|
100 |
+
...
|
101 |
+
|
102 |
+
@abstractmethod
|
103 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
104 |
+
"""See `EncodecModel.decode`."""
|
105 |
+
...
|
106 |
+
|
107 |
+
@abstractmethod
|
108 |
+
def decode_latent(self, codes: torch.Tensor):
|
109 |
+
"""Decode from the discrete codes to continuous latent space."""
|
110 |
+
...
|
111 |
+
|
112 |
+
@property
|
113 |
+
@abstractmethod
|
114 |
+
def channels(self) -> int:
|
115 |
+
...
|
116 |
+
|
117 |
+
@property
|
118 |
+
@abstractmethod
|
119 |
+
def frame_rate(self) -> float:
|
120 |
+
...
|
121 |
+
|
122 |
+
@property
|
123 |
+
@abstractmethod
|
124 |
+
def sample_rate(self) -> int:
|
125 |
+
...
|
126 |
+
|
127 |
+
@property
|
128 |
+
@abstractmethod
|
129 |
+
def cardinality(self) -> int:
|
130 |
+
...
|
131 |
+
|
132 |
+
@property
|
133 |
+
@abstractmethod
|
134 |
+
def num_codebooks(self) -> int:
|
135 |
+
...
|
136 |
+
|
137 |
+
@property
|
138 |
+
@abstractmethod
|
139 |
+
def total_codebooks(self) -> int:
|
140 |
+
...
|
141 |
+
|
142 |
+
@abstractmethod
|
143 |
+
def set_num_codebooks(self, n: int):
|
144 |
+
"""Set the active number of codebooks used by the quantizer."""
|
145 |
+
...
|
146 |
+
|
147 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
|
148 |
+
assert norm in CONV_NORMALIZATIONS
|
149 |
+
if norm == 'weight_norm':
|
150 |
+
return weight_norm(module)
|
151 |
+
elif norm == 'spectral_norm':
|
152 |
+
return spectral_norm(module)
|
153 |
+
else:
|
154 |
+
# We already check was in CONV_NORMALIZATION, so any other choice
|
155 |
+
# doesn't need reparametrization.
|
156 |
+
return module
|
157 |
+
|
158 |
+
|
159 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
|
160 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
161 |
+
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
162 |
+
"""
|
163 |
+
assert norm in CONV_NORMALIZATIONS
|
164 |
+
if norm == 'time_group_norm':
|
165 |
+
if causal:
|
166 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
167 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
168 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
169 |
+
else:
|
170 |
+
return nn.Identity()
|
171 |
+
|
172 |
+
|
173 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
174 |
+
padding_total: int = 0) -> int:
|
175 |
+
"""See `pad_for_conv1d`."""
|
176 |
+
length = x.shape[-1]
|
177 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
178 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
179 |
+
return ideal_length - length
|
180 |
+
|
181 |
+
|
182 |
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
183 |
+
"""Pad for a convolution to make sure that the last window is full.
|
184 |
+
Extra padding is added at the end. This is required to ensure that we can rebuild
|
185 |
+
an output of the same length, as otherwise, even with padding, some time steps
|
186 |
+
might get removed.
|
187 |
+
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
188 |
+
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
189 |
+
1 2 3 # (output frames of a convolution, last 0 is never used)
|
190 |
+
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
191 |
+
1 2 3 4 # once you removed padding, we are missing one time step !
|
192 |
+
"""
|
193 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
194 |
+
return F.pad(x, (0, extra_padding))
|
195 |
+
|
196 |
+
|
197 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
|
198 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
199 |
+
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
200 |
+
"""
|
201 |
+
length = x.shape[-1]
|
202 |
+
padding_left, padding_right = paddings
|
203 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
204 |
+
if mode == 'reflect':
|
205 |
+
max_pad = max(padding_left, padding_right)
|
206 |
+
extra_pad = 0
|
207 |
+
if length <= max_pad:
|
208 |
+
extra_pad = max_pad - length + 1
|
209 |
+
x = F.pad(x, (0, extra_pad))
|
210 |
+
padded = F.pad(x, paddings, mode, value)
|
211 |
+
end = padded.shape[-1] - extra_pad
|
212 |
+
return padded[..., :end]
|
213 |
+
else:
|
214 |
+
return F.pad(x, paddings, mode, value)
|
215 |
+
|
216 |
+
|
217 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
218 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
219 |
+
padding_left, padding_right = paddings
|
220 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
221 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
222 |
+
end = x.shape[-1] - padding_right
|
223 |
+
return x[..., padding_left: end]
|
224 |
+
|
225 |
+
|
226 |
+
class NormConv1d(nn.Module):
|
227 |
+
"""Wrapper around Conv1d and normalization applied to this conv
|
228 |
+
to provide a uniform interface across normalization approaches.
|
229 |
+
"""
|
230 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
231 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
232 |
+
super().__init__()
|
233 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
234 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
235 |
+
self.norm_type = norm
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
x = self.conv(x)
|
239 |
+
x = self.norm(x)
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class NormConv2d(nn.Module):
|
244 |
+
"""Wrapper around Conv2d and normalization applied to this conv
|
245 |
+
to provide a uniform interface across normalization approaches.
|
246 |
+
"""
|
247 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
248 |
+
super().__init__()
|
249 |
+
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
250 |
+
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
251 |
+
self.norm_type = norm
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
x = self.conv(x)
|
255 |
+
x = self.norm(x)
|
256 |
+
return x
|
257 |
+
|
258 |
+
|
259 |
+
class NormConvTranspose1d(nn.Module):
|
260 |
+
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
261 |
+
to provide a uniform interface across normalization approaches.
|
262 |
+
"""
|
263 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
264 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
265 |
+
super().__init__()
|
266 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
267 |
+
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
268 |
+
self.norm_type = norm
|
269 |
+
|
270 |
+
def forward(self, x):
|
271 |
+
x = self.convtr(x)
|
272 |
+
x = self.norm(x)
|
273 |
+
return x
|
274 |
+
|
275 |
+
|
276 |
+
class NormConvTranspose2d(nn.Module):
|
277 |
+
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
278 |
+
to provide a uniform interface across normalization approaches.
|
279 |
+
"""
|
280 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
281 |
+
super().__init__()
|
282 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
283 |
+
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
x = self.convtr(x)
|
287 |
+
x = self.norm(x)
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
class StreamableConv1d(nn.Module):
|
292 |
+
"""Conv1d with some builtin handling of asymmetric or causal padding
|
293 |
+
and normalization.
|
294 |
+
"""
|
295 |
+
def __init__(self, in_channels: int, out_channels: int,
|
296 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
297 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
298 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
299 |
+
pad_mode: str = 'reflect'):
|
300 |
+
super().__init__()
|
301 |
+
# warn user on unusual setup between dilation and stride
|
302 |
+
if stride > 1 and dilation > 1:
|
303 |
+
warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
|
304 |
+
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
|
305 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
306 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
307 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
308 |
+
self.causal = causal
|
309 |
+
self.pad_mode = pad_mode
|
310 |
+
|
311 |
+
def forward(self, x):
|
312 |
+
B, C, T = x.shape
|
313 |
+
kernel_size = self.conv.conv.kernel_size[0]
|
314 |
+
stride = self.conv.conv.stride[0]
|
315 |
+
dilation = self.conv.conv.dilation[0]
|
316 |
+
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
317 |
+
padding_total = kernel_size - stride
|
318 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
319 |
+
if self.causal:
|
320 |
+
# Left padding for causal
|
321 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
322 |
+
else:
|
323 |
+
# Asymmetric padding required for odd strides
|
324 |
+
padding_right = padding_total // 2
|
325 |
+
padding_left = padding_total - padding_right
|
326 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
327 |
+
return self.conv(x)
|
328 |
+
|
329 |
+
|
330 |
+
class StreamableConvTranspose1d(nn.Module):
|
331 |
+
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
332 |
+
and normalization.
|
333 |
+
"""
|
334 |
+
def __init__(self, in_channels: int, out_channels: int,
|
335 |
+
kernel_size: int, stride: int = 1, causal: bool = False,
|
336 |
+
norm: str = 'none', trim_right_ratio: float = 1.,
|
337 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}):
|
338 |
+
super().__init__()
|
339 |
+
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
340 |
+
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
|
341 |
+
self.causal = causal
|
342 |
+
self.trim_right_ratio = trim_right_ratio
|
343 |
+
assert self.causal or self.trim_right_ratio == 1., \
|
344 |
+
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
345 |
+
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
346 |
+
|
347 |
+
def forward(self, x):
|
348 |
+
kernel_size = self.convtr.convtr.kernel_size[0]
|
349 |
+
stride = self.convtr.convtr.stride[0]
|
350 |
+
padding_total = kernel_size - stride
|
351 |
+
|
352 |
+
y = self.convtr(x)
|
353 |
+
|
354 |
+
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
355 |
+
# removed at the very end, when keeping only the right length for the output,
|
356 |
+
# as removing it here would require also passing the length at the matching layer
|
357 |
+
# in the encoder.
|
358 |
+
if self.causal:
|
359 |
+
# Trim the padding on the right according to the specified ratio
|
360 |
+
# if trim_right_ratio = 1.0, trim everything from right
|
361 |
+
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
362 |
+
padding_left = padding_total - padding_right
|
363 |
+
y = unpad1d(y, (padding_left, padding_right))
|
364 |
+
else:
|
365 |
+
# Asymmetric padding required for odd strides
|
366 |
+
padding_right = padding_total // 2
|
367 |
+
padding_left = padding_total - padding_right
|
368 |
+
y = unpad1d(y, (padding_left, padding_right))
|
369 |
+
return y
|
370 |
+
|
371 |
+
|
372 |
+
class StreamableLSTM(nn.Module):
|
373 |
+
"""LSTM without worrying about the hidden state, nor the layout of the data.
|
374 |
+
Expects input as convolutional layout.
|
375 |
+
"""
|
376 |
+
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
377 |
+
super().__init__()
|
378 |
+
self.skip = skip
|
379 |
+
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
380 |
+
|
381 |
+
def forward(self, x):
|
382 |
+
x = x.permute(2, 0, 1)
|
383 |
+
y, _ = self.lstm(x)
|
384 |
+
if self.skip:
|
385 |
+
y = y + x
|
386 |
+
y = y.permute(1, 2, 0)
|
387 |
+
return y
|
388 |
+
|
389 |
+
|
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)

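# Shape-preservation sketch (illustrative; the sizes are assumptions):
#   block = SEANetResnetBlock(dim=64, causal=True, pad_mode='constant')
#   x = torch.randn(2, 64, 100)          # [B, C, T]
#   assert block(x).shape == x.shape     # residual branch and shortcut both keep [B, C, T]
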
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

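# Worked example of the temporal downsampling (a sketch; the values are assumptions):
# with the default ratios [8, 5, 4, 2], hop_length = 8 * 5 * 4 * 2 = 320, so 16 kHz
# input audio yields a latent frame rate of 16000 / 320 = 50 Hz:
#   enc = SEANetEncoder(channels=1, dimension=128, ratios=[8, 5, 4, 2])
#   emb = enc(torch.randn(1, 1, 16000))   # one second of audio -> [1, 128, 50]
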
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the beginning of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y

def exists(val: tp.Optional[tp.Any]) -> bool:
    return val is not None


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if exists(val) else d


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    n = t.shape[0]
    normed_codes = l2norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)

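# Numeric sketch of ema_inplace (illustrative):
#   avg = torch.tensor([1.0])
#   ema_inplace(avg, torch.tensor([3.0]), decay=0.8)
#   # avg is now 1.0 * 0.8 + 3.0 * 0.2 = 1.4
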
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        # Training of the codebook is intentionally disabled in this codebase; only
        # `encode`/`decode` are used. The original EMA training path is kept below
        # for reference but is unreachable.
        raise NotImplementedError()
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind

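# Codebook round-trip sketch (illustrative; the sizes are assumptions):
#   cb = EuclideanCodebook(dim=8, codebook_size=16)
#   idx = cb.encode(torch.randn(4, 10, 8))   # [B, T, dim] -> [4, 10] integer codes
#   recon = cb.decode(idx)                   # [4, 10, 8], the nearest centroids
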
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size

        self.channels_last = channels_last

    @property
    def codebook(self):
        return self._codebook.embed

    @property
    def inited(self):
        return self._codebook.inited

    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def _postprocess(self, quantize):
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize

    def forward(self, x):
        device = x.device
        x = self._preprocess(x)

        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through estimator: the forward value is the codebook vector,
            # but gradients flow to x as if quantization were the identity.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]

                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)

        return quantize, embed_ind, loss

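# Single-codebook usage sketch (illustrative; sizes are assumptions). Note that in
# this codebase EuclideanCodebook.forward raises NotImplementedError, so only the
# encode/decode path is exercised:
#   vq = VectorQuantization(dim=128, codebook_size=1024)
#   idx = vq.encode(torch.randn(2, 128, 50))   # [B, D, T] -> [2, 50] codes
#   q = vq.decode(idx)                         # -> [2, 128, 50]
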
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        codebook_size = kwargs.pop('codebook_size', None)
        if codebook_size is None:
            raise ValueError("codebook_size must be provided in kwargs")
        if not isinstance(codebook_size, list):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VectorQuantization(codebook_size=cur_codebook_size, **kwargs)
             for _, cur_codebook_size in zip(range(num_quantizers), codebook_size)]
        )

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            # The original EnCodec code used `residual = residual - quantized.detach()`.
            # Since `quantize` carries the gradient of the residual (see the
            # straight-through estimator in VectorQuantization.forward), detaching here
            # would zero the commitment loss for every codebook except the first one,
            # see https://github.com/facebookresearch/encodec/issues/25 -- hence the
            # change below. Because the commitment loss is averaged, its scale does not
            # change (not as said in the issue above).
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out

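# Residual refinement sketch (illustrative; the sizes are assumptions): each stage
# quantizes what the previous stages missed, and decode sums the per-stage outputs:
#   rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=1024)
#   codes = rvq.encode(torch.randn(2, 128, 50))   # -> [4, 2, 50], one row per stage
#   recon = rvq.decode(codes)                     # -> [2, 128, 50]
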
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int or list of int): Codebook size, per quantizer if a list is given.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: tp.Union[int, tp.List[int]] = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        if isinstance(self.bins, list):
            bins = self.bins
        else:
            bins = [self.bins] * self.n_q
        # Bandwidth in kbps: log2(codebook size) bits per frame, per codebook.
        bw_per_q = [math.log2(b) * frame_rate / 1000 for b in bins]
        bw = torch.tensor(sum(bw_per_q)).to(x)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n

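# Bandwidth arithmetic used in forward above (worked example; illustrative numbers):
# with bins = 2048 and frame_rate = 50, each codebook carries log2(2048) = 11 bits
# per frame, i.e. 11 * 50 / 1000 = 0.55 kbps, and n_q = 8 codebooks give 4.4 kbps.
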
class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Total number of codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")


class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            raise NotImplementedError("model forward and training is not supported.")
        # The reconstruction path below is kept for reference but is unreachable.
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)

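# End-to-end usage sketch (illustrative; `model` is assumed to be an EncodecModel
# instance, e.g. from get_compression_model below -- see also the __main__ demo at
# the bottom of this file):
#   codes, scale = model.encode(wav)      # wav: [B, C, T] float tensor
#   recon = model.decode(codes, scale)    # may be slightly longer than wav
#   recon = recon[..., :wav.shape[-1]]    # trim the decoder's extra padding
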
class EncodecModel_encode_only(CompressionModel):
    """Encodec model operating on the raw waveform. Encode only, so no decoder.

    Args:
        encoder (nn.Module): Encoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        # The reconstruction path of EncodecModel is not available here: this
        # encode-only variant has no decoder.
        raise NotImplementedError("model forward and training is not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decoding is not available: this variant keeps no decoder weights."""
        raise NotImplementedError("Decode is not supported for encode only model")

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        raise NotImplementedError("Decode is not supported for encode only model")

def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
    klass = {
        'no_quant': DummyQuantizer,
        'rvq': ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)


def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    if encoder_name == 'seanet':
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        encoder_override_kwargs = kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = SEANetEncoder(**encoder_kwargs)
        decoder = SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")


def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
    """Instantiate a compression model."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load(ckpt_fn, map_location='cpu')
    cfg = state['xp.cfg']
    cfg.device = str(device)
    weights = state['best_state']['model']
    assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
    if encode_only:
        # Drop the decoder weights so the encode-only model can load the state dict.
        all_keys = list(weights.keys())
        for key in all_keys:
            if key.startswith('decoder'):
                del weights[key]
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel_encode_only(encoder, quantizer,
                                                     frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model

    else:
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel(encoder, decoder, quantizer,
                                         frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model

if __name__ == "__main__":
    import torchaudio
    ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
    audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
    audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_compression_model(ckpt_fn, device=device)

    for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
        audio_in, sr = torchaudio.load(audio_in_fn)
        if sr != model.sample_rate:
            audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
        if audio_in.shape[0] == 2:
            # downmix stereo to mono
            audio_in = audio_in.mean(dim=0, keepdim=True)
        audio_in = audio_in.unsqueeze(0)
        audio_in = audio_in.to(torch.float32).to(device)
        codes = model.encode(audio_in)[0]
        audio_out = model.decode(codes)[0].cpu()
        torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
data/emilia_preprocessing/sha256hash.py
ADDED
@@ -0,0 +1,14 @@
import hashlib
import sys

def sha256_hash_file(filename):
    sha256_hash = hashlib.sha256()
    with open(filename, "rb") as file:
        # Read and update hash string in chunks to handle large files
        for byte_block in iter(lambda: file.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()

# Usage example
filename = sys.argv[1]
print(sha256_hash_file(filename))
data/emilia_preprocessing/step1_download.py
ADDED
@@ -0,0 +1,9 @@
# conda activate emilia
from datasets import load_dataset
import fire

def main(root: str = "/data/scratch/pyp/datasets/emilia"):
    path = "EN/*.tar.gz*"
    dataset = load_dataset("amphion/Emilia-Dataset", data_files={"en": path}, split="en", streaming=False, revision="fc71e07e8572f5f3be1dbd02ed3172a4d298f152", cache_dir=root)

if __name__ == "__main__":
    fire.Fire(main)
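# Invocation sketch (illustrative): `python step1_download.py --root /path/to/emilia`
# downloads the English .tar.gz shards of the Emilia dataset into `root`.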
data/emilia_preprocessing/step2_log_tar_files.sh
ADDED
@@ -0,0 +1,27 @@
emilia_root=$1 # /data/scratch/pyp/datasets/emilia/downloads
for file in ${emilia_root}/* ; do
    # Check if gzip compressed archive
    if file "$file" | grep -q 'gzip compressed data'; then
        # Extract a string of the form 'was "EN_B00100.tar"' from the output of the file command to keep EN_B00100
        filename=$(file "$file" | grep -oP '(?<=was ")[^"]*' | sed 's/\.tar$//')
        # Get the file size
        size=$(du -sh "$file" | cut -f1)
        original_filename=$(basename "$file")
        # Get URL string from corresponding JSON file with same basename
        json_file=$file.json
        if [ -f "$json_file" ]; then
            # url=$(jq -r '.url' "$json_file") # jq is not installed on the server
            url=$(python3 -c "import sys, json; print(json.load(open('$json_file'))['url'])")
        else
            url="N/A"
        fi
        # Compute SHA256 hash of the file
        hash=$(python sha256hash.py "$file")
        echo $original_filename
        # Write filename, size, hash, original filename, URL to output file
        echo "$filename, $size, $hash, $original_filename, $url" >> file_log.txt
    fi
done

# Sort the output file by filename
sort -o file_log.txt file_log.txt
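
# Example line appended to file_log.txt (illustrative placeholder values only):
# EN_B00100, 3.5G, <sha256-hex>, <hash-named download file>, <original download URL>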
data/emilia_preprocessing/step3_untar.sh
ADDED
@@ -0,0 +1,101 @@
#!/bin/bash

# Define the root directory where the tar files are located
root=$1       # e.g. /data/scratch/pyp/datasets/emilia/downloads
save_root=$2  # e.g. /data/scratch/pyp/datasets/emilia/preprocessed/audio

mkdir -p "${save_root}"

# Input log files
log_file="file_log.txt"             # Full log of files to process
exist_log_file="file_log_debug.txt" # Log of already processed files
failure_log="untar_failures.log"    # Log file for untar failures

# Clear previous failure log
> "$failure_log"

# Create an array of filenames already processed (from exist_log_file);
# also make sure the file exists so the grep -f filter below works on a first run
if [ -f "$exist_log_file" ]; then
    mapfile -t existing_files < "$exist_log_file"
else
    existing_files=()
    touch "$exist_log_file"
fi

# Create a temporary filtered log of files to process
filtered_log="filtered_file_log.txt"
grep -v -F -f "$exist_log_file" "$log_file" > "$filtered_log"

# Count total filtered files
total_files=$(wc -l < "$filtered_log")
echo "Found $total_files entries to process in $filtered_log."

# Print the filtered files
echo "Filtered files to process:"
cat "$filtered_log"
echo

# Confirm before starting processing
read -p "Do you want to proceed with the above files? (y/n): " confirm
if [[ "$confirm" != "y" ]]; then
    echo "Operation canceled."
    rm -f "$filtered_log"
    exit 1
fi

# Start time
start_time=$(date +%s)

# Counter for how many lines we've processed
count=0

# Process filtered log
while IFS=',' read -r filename size local_sha256 original_filename url; do
    count=$((count + 1))

    # Trim leading/trailing whitespace
    filename=$(echo "$filename" | xargs)
    size=$(echo "$size" | xargs)
    local_sha256=$(echo "$local_sha256" | xargs)
    original_filename=$(echo "$original_filename" | xargs)
    url=$(echo "$url" | xargs)

    # Construct the full path to the tar file
    tar_file="${root}/${original_filename}"

    # Check if the tar file exists
    if [ ! -f "$tar_file" ]; then
        echo "❌ File not found: $tar_file"
        echo "$filename, $size, $local_sha256, $original_filename, $url" >> "$failure_log"
    else
        # Try to untar the file
        echo "[$count/$total_files] Untarring $tar_file..."

        if ! tar -xf "$tar_file" -C "${save_root}"; then
            # If untar fails, log the failure
            echo "❌ Failed to untar: $tar_file"
            echo "$filename, $size, $local_sha256, $original_filename, $url" >> "$failure_log"
        else
            echo "✅ Successfully untarred: $tar_file"
            # Append successfully untarred filename to exist_log_file
            echo "$filename" >> "$exist_log_file"
        fi
    fi

    # Calculate elapsed time, average time per file, and ETA
    now=$(date +%s)
    elapsed=$(( now - start_time ))  # total seconds since the start
    if [ $count -gt 0 ]; then
        avg_time=$(awk "BEGIN { printf \"%.2f\", $elapsed / $count }")
        remain=$(( total_files - count ))
        eta_seconds=$(awk "BEGIN { printf \"%.0f\", $avg_time * $remain }")
        eta_formatted=$(date -ud "@${eta_seconds}" +'%H:%M:%S')
        echo "Elapsed: ${elapsed}s | Avg/file: ${avg_time}s | Remaining: $remain files | ETA: ~${eta_formatted}"
    fi

done < "$filtered_log"

# Clean up temporary filtered log
rm -f "$filtered_log"

# Summary
echo "Untar operation completed. Check $failure_log for any failures."
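A typical invocation passes the download directory and the extraction target as the two positional arguments, e.g. `bash step3_untar.sh /data/scratch/pyp/datasets/emilia/downloads /data/scratch/pyp/datasets/emilia/preprocessed/audio` (paths taken from the inline comments; adjust to your setup). The script lists the pending tars and asks for confirmation before extracting, so answering "n" lets you inspect the plan without touching anything.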
data/emilia_preprocessing/step4_construct_manifest.py
ADDED
@@ -0,0 +1,251 @@
# Construct the manifest file for training; note that we only have one train split.
# Also create a neighbors folder: for each sample, neighbors are found through the
# speaker label in the original manifest, and each neighbors file has rows
#     path\tdistance\tduration
# where distance is always 0 because we don't know the distance between the samples.

# waiting on Yushen Chen to provide data filtering approach
import sys, copy
import os, random, numpy as np, socket

import json
import tqdm
from multiprocessing import Pool
import glob, os
from collections import defaultdict


def write_jsonl(data, fn):
    with open(fn, "w") as file:
        for entry in data:
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_jsonl(file_path):
    cur_data = []
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            cur_data.append(json.loads(line.strip()))
    return cur_data


def repetition_found(text, length=2, tolerance=10):
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i : i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False


out_en = {
    "EN_B00013_S00913",
    "EN_B00042_S00120",
    "EN_B00055_S04111",
    "EN_B00061_S00693",
    "EN_B00061_S01494",
    "EN_B00061_S03375",
    "EN_B00059_S00092",
    "EN_B00111_S04300",
    "EN_B00100_S03759",
    "EN_B00087_S03811",
    "EN_B00059_S00950",
    "EN_B00089_S00946",
    "EN_B00078_S05127",
    "EN_B00070_S04089",
    "EN_B00074_S09659",
    "EN_B00061_S06983",
    "EN_B00061_S07060",
    "EN_B00059_S08397",
    "EN_B00082_S06192",
    "EN_B00091_S01238",
    "EN_B00089_S07349",
    "EN_B00070_S04343",
    "EN_B00061_S02400",
    "EN_B00076_S01262",
    "EN_B00068_S06467",
    "EN_B00076_S02943",
    "EN_B00064_S05954",
    "EN_B00061_S05386",
    "EN_B00066_S06544",
    "EN_B00076_S06944",
    "EN_B00072_S08620",
    "EN_B00076_S07135",
    "EN_B00076_S09127",
    "EN_B00065_S00497",
    "EN_B00059_S06227",
    "EN_B00063_S02859",
    "EN_B00075_S01547",
    "EN_B00061_S08286",
    "EN_B00079_S02901",
    "EN_B00092_S03643",
    "EN_B00096_S08653",
    "EN_B00063_S04297",
    "EN_B00063_S04614",
    "EN_B00079_S04698",
    "EN_B00104_S01666",
    "EN_B00061_S09504",
    "EN_B00061_S09694",
    "EN_B00065_S05444",
    "EN_B00063_S06860",
    "EN_B00065_S05725",
    "EN_B00069_S07628",
    "EN_B00083_S03875",
    "EN_B00071_S07665",
    "EN_B00071_S07665",
    "EN_B00062_S04187",
    "EN_B00065_S09873",
    "EN_B00065_S09922",
    "EN_B00084_S02463",
    "EN_B00067_S05066",
    "EN_B00106_S08060",
    "EN_B00073_S06399",
    "EN_B00073_S09236",
    "EN_B00087_S00432",
    "EN_B00085_S05618",
    "EN_B00064_S01262",
    "EN_B00072_S01739",
    "EN_B00059_S03913",
    "EN_B00069_S04036",
    "EN_B00067_S05623",
    "EN_B00060_S05389",
    "EN_B00060_S07290",
    "EN_B00062_S08995",
}
en_filters = ["ا", "い", "て"]


def process_meta_item(item, root, sub_root, audio_folder, audio_ext, text_ext):
    # Data filtering following Yushen's approach
    if (
        item["wav"].split("/")[-1] in out_en
        or any(t in item["text"] for t in en_filters)
        or repetition_found(item["text"], length=4)
    ):
        return None, item["duration"], 1, 0, 0, (None, None)  # Return filtered results

    # Trim leading space from text if it exists
    if item["text"].startswith(" "):
        item["text"] = item["text"][1:]

    # Write text to text file
    text_fn = os.path.join(root, sub_root, audio_folder, item["wav"].replace(audio_ext, text_ext))
    os.makedirs(os.path.dirname(text_fn), exist_ok=True)
    with open(text_fn, "w") as f:
        f.write(item["text"])

    return (
        f"{item['wav']}\t{item['duration']}\n",
        0,
        0,
        item["duration"],
        1,
        (item['speaker'], item)
    )  # Return processed results


def parallel_process_meta(meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext):
    with Pool(num_workers) as pool:
        results = pool.starmap(
            process_meta_item,
            [(item, root, sub_root, audio_folder, audio_ext, text_ext) for item in meta],
        )

    processed_items = []
    spkitem = []
    filtered_duration = 0
    filtered_count = 0
    total_duration = 0
    total_count = 0

    for result in results:
        if result[0]:  # If the item was processed
            processed_items.append(result[0])
        filtered_duration += result[1]
        filtered_count += result[2]
        total_duration += result[3]
        total_count += result[4]
        spkitem.append(result[5])

    return processed_items, filtered_duration, filtered_count, total_duration, total_count, spkitem


def main(
    root: str = "/data/scratch/pyp/datasets/emilia",
    sub_root: str = "preprocessed",
    audio_folder: str = "audio",
    manifest_folder: str = "manifest_for_codec",
    neighbors_folder: str = "neighbors",
    audio_ext: str = ".mp3",
    text_ext: str = ".txt",
    num_workers: int = 8,  # Specify the number of workers
):
    # Find the segments that are untarred
    all_fns = [
        item
        for item in glob.glob(f"{root}/{sub_root}/{audio_folder}/*")
        if os.path.basename(item).startswith("EN_") and os.path.isdir(item)
    ]
    print(f"found {len(all_fns)} untarred segments")
    print(f"{all_fns[:3]}")

    res = []
    total_duration = 0
    total_count = 0
    filtered_duration = 0
    filtered_count = 0

    for fn in tqdm.tqdm(all_fns, desc="overall progress"):
        spk2info = defaultdict(list)
        metafn = os.path.join(root, "EN", os.path.basename(fn) + ".jsonl")
        meta = read_jsonl(metafn)

        # Parallel process metadata
        processed_items, fd, fc, td, tc, spkitem = parallel_process_meta(
            meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext
        )

        # Aggregate results
        res.extend(processed_items)
        filtered_duration += fd
        filtered_count += fc
        total_duration += td
        total_count += tc

        for spk, item in spkitem:
            if spk:
                spk2info[spk].append(item)

        # Save neighbor files
        for spk in spk2info:
            for item in spk2info[spk]:
                neighbor_fn = os.path.join(
                    root,
                    sub_root,
                    neighbors_folder,
                    item["wav"].replace(audio_ext, text_ext),
                )
                os.makedirs(os.path.dirname(neighbor_fn), exist_ok=True)
                tobe_write = [
                    f"{neighbor_item['wav'].replace(audio_ext, text_ext)}\t0\t{neighbor_item['duration']}\n"
                    for neighbor_item in spk2info[spk]
                    if neighbor_item["wav"] != item["wav"]
                ]
                if tobe_write:
                    with open(neighbor_fn, "w") as f:
                        f.writelines(tobe_write)

    print(
        f"total duration: {total_duration / 3600:.2f} hours, total count: {total_count}"
    )
    print(
        f"filtered duration: {filtered_duration / 3600:.2f} hours, filtered count: {filtered_count}"
    )
    save_fn = os.path.join(root, sub_root, manifest_folder, "train.txt")
    os.makedirs(os.path.dirname(save_fn), exist_ok=True)
    with open(save_fn, "w") as f:
        for item in res:
            f.write(item)


if __name__ == "__main__":
    import fire

    fire.Fire(main)
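To make the on-disk layout concrete: the train manifest holds one wav_path\tduration row per kept sample, and each neighbors file holds one path\tdistance\tduration row per same-speaker sample, with distance fixed at 0 as the header comment notes. A minimal sketch for reading a neighbors file back (the file path below is hypothetical):

# Minimal sketch: parse one neighbors file written by main() above.
# Rows are "path\tdistance\tduration"; distance is always 0 here.
neighbors = []
with open("neighbors/EN_B00008/EN_B00008_S00001_W000000.txt") as f:  # hypothetical path
    for line in f:
        path, distance, duration = line.rstrip("\n").split("\t")
        neighbors.append((path, float(distance), float(duration)))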
data/emilia_preprocessing/step5_phonemize.py
ADDED
@@ -0,0 +1,158 @@
import sys, copy
import os, random, numpy as np, socket

import json
import tqdm
from multiprocessing import Pool
import glob, os, fire
from collections import defaultdict
sys.path.insert(0, "../../")
from data.tokenizer import TextTokenizer, tokenize_text


def write_jsonl(data, fn):
    with open(fn, "w") as file:
        for entry in data:
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_jsonl(file_path):
    cur_data = []
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            cur_data.append(json.loads(line.strip()))
    return cur_data


def phonemize_and_save(text, fn, text_tokenizer):
    """Phonemize the text and save the result to a file."""
    phn = tokenize_text(text_tokenizer, text)
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    with open(fn, "w") as f:
        f.write(" ".join(phn))
    return set(phn)


def process_item(item, root, sub_root, audio_folder, phn_folder, audio_ext, text_ext, phn_ext, text_tokenizer):
    """Worker function to process a single item."""
    text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext))
    if not os.path.exists(text_path):
        return {"missing_text": text_path, "success": False, "cur_phn_set": set()}

    with open(text_path, "r") as f:
        text = [line.strip() for line in f.readlines()]
    text = " ".join(text)

    phn_path = os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext))
    cur_phn_set = phonemize_and_save(text, phn_path, text_tokenizer)
    return {"missing_text": None, "success": True, "cur_phn_set": cur_phn_set}


def process_item_star(args):
    """Unpacks arguments for `process_item` to work with `imap`."""
    return process_item(*args)


def main(
    root="/data/scratch/pyp/datasets/emilia",
    sub_root="preprocessed",
    manifest_folder="manifest_for_codec",
    audio_folder="audio",
    phn_folder="phoneme",
    audio_ext=".mp3",
    text_ext=".txt",
    phn_ext=".txt",
    num_workers=8,
):
    """Main function to run phoneme generation."""
    # Initialize the tokenizer
    text_tokenizer = TextTokenizer()
    all_fns = glob.glob(f"{root}/{sub_root}/{manifest_folder}/*.txt")
    print(f"found {len(all_fns)} manifest files")
    print(f"{all_fns[:3]=}")

    data = []
    for fn in all_fns:
        with open(fn, "r") as f:
            data += [line.strip().split("\t") for line in f]

    vocab = set()

    ################## parallel processing (kept for reference, currently disabled) ##################
    # Prepare arguments for the worker function
    # tasks = [
    #     (item, root, sub_root, audio_folder, phn_folder, audio_ext, text_ext, phn_ext, text_tokenizer)
    #     for item in data
    # ]
    # # Parallel processing with progress monitoring
    # results = []
    # with Pool(num_workers) as pool:
    #     for result in tqdm.tqdm(
    #         pool.imap_unordered(process_item_star, tasks),
    #         total=len(tasks),
    #         desc="Processing items",
    #     ):
    #         results.append(result)
    # missing_text = [result["missing_text"] for result in results if not result["success"]]
    # for result in results:
    #     if result['success']:
    #         vocab.update(result['cur_phn_set'])

    ################## sequential processing ##################
    missing_text = []
    for item in tqdm.tqdm(data):
        text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext))
        if not os.path.exists(text_path):
            missing_text.append(text_path)
            continue
        try:
            with open(text_path, "r") as f:
                text = [line.strip() for line in f.readlines()]
            text = " ".join(text)
        except Exception:
            print(f"Error reading {text_path}")
            continue
        cur_phn_set = phonemize_and_save(text, os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext)), text_tokenizer)
        vocab.update(cur_phn_set)

    # Save the sorted vocab, one phoneme per line
    vocab = sorted(vocab)
    with open(os.path.join(root, sub_root, "vocab.txt"), "w") as f:
        f.write("\n".join(vocab))

    # Report missing text paths
    print(f"Missing text files: {len(missing_text)}")
    if missing_text:
        print("Some missing files:", missing_text[:10])  # Print the first 10 missing files as an example


if __name__ == "__main__":
    fire.Fire(main)
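The script's outputs are plain text: each phoneme file stores the tokens of one utterance joined by single spaces, and vocab.txt stores the sorted phoneme inventory one token per line. A minimal sketch of how a downstream step might map a phonemized file to integer ids, assuming this exact layout (the paths are hypothetical, and whether training builds ids this way is not shown in this file):

# Minimal sketch: build a token-to-id table from vocab.txt and apply it
# to one phoneme file (paths are hypothetical).
with open("vocab.txt") as f:
    phn2id = {phn: i for i, phn in enumerate(f.read().splitlines())}
with open("phoneme/EN_B00008/EN_B00008_S00001_W000000.txt") as f:
    ids = [phn2id[p] for p in f.read().split(" ")]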
data/emilia_preprocessing/step6_encodec_encode.py
ADDED
@@ -0,0 +1,177 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="encode the dataset using codec model")
    parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the total batch size summed over all GPUs, so increase it if you are using more GPUs")
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number of seconds')
    parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number of seconds')
    parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return parser.parse_args()


if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    import os, sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time

    args = parse_args()

    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()

        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]

    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a)) + '\n')
            f.write(' '.join(map(str, array[-1])))

    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            self.split = split
            self.audio_dir = audio_dir
            manifest_fn = os.path.join(encodec_manifest_dir, split + ".txt")
            cur_sp = int(args.partition.split("/")[0]) - 1
            total_sp = int(args.partition.split("/")[1])
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_dir, afn)
                audio, sr = torchaudio.load(fn)
                if sr != args.model_sr:
                    audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
                    sr = args.model_sr
                assert sr == args.model_sr, sr
            except Exception:
                return None, None, None
            assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

        def collate(self, batch):
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] is not None:
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids

    # roots
    sub_root = args.sub_root
    encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
    audio_dir = os.path.join(args.root, sub_root, "audio")
    save_manifest_dir = os.path.join(args.root, sub_root, "manifest_final_encodec")
    if args.encodec_name == "encodec_6f79c6a8.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_4cb")
    elif args.encodec_name == "encodec_8cb1024_giga.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_8cb")
    else:
        raise ValueError(f"unknown encodec_name: {args.encodec_name}")

    os.makedirs(save_manifest_dir, exist_ok=True)
    os.makedirs(save_codes_dir, exist_ok=True)

    def import_encodec():
        from encodec import get_compression_model
        userdir = os.path.expanduser("~")
        model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
        model = torch.nn.DataParallel(model)
        return model
    model = import_encodec()

    # setup dataloader
    mega_batch_size = 2048
    batch_size = args.batch_size

    dataset = mydataset(args.split)
    if len(dataset) == 0:
        logging.info(f"no data found for split {args.split} partition {args.partition}")
        sys.exit(0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
    split = args.split

    skip = 0
    logging.info(f"now processing split {split} partition {args.partition}...")
    mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
    logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
    mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
    logging.info(f"manifest for split {split} partition {args.partition} will be saved at {mani_fn}")
    with open(mani_fn, "w") as mani_wf:
    # with open(mani_fn, "a") as mani_wf:  # resume from where we failed
        for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):

            logging.info("====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")

            try:
                lengths = np.array(mega_batch[1])
                sorted_inds = sort_by_audio_len(lengths)
                for j in range(len(sorted_inds))[::-1]:
                    # skip samples that are too short (< min_len sec) or too long (> len_cap sec)
                    if lengths[sorted_inds[j]] < args.model_sr * args.min_len or lengths[sorted_inds[j]] > args.model_sr * args.len_cap:
                        skip += 1
                        del sorted_inds[j]

                n_steps = int(np.ceil(len(sorted_inds) / batch_size))
                for n in tqdm.tqdm(range(n_steps), disable=True):
                    inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                    wav_batch = [mega_batch[0][id] for id in inds_used]
                    all_lens = [mega_batch[1][id] for id in inds_used]
                    segment_id_batch = [mega_batch[2][id] for id in inds_used]
                    padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1)  # [B, T] -> [B, 1, T]
                    # Extract discrete codes from EnCodec
                    with torch.no_grad():
                        if max(all_lens) > 300000 and len(all_lens) > 1:  # if utterances are long, simply pass half of them at a time
                            codes = []
                            inwav = padded_wav.cuda()
                            codes.append(model(inwav[:len(inwav)//2])[0].cpu())
                            codes.append(model(inwav[len(inwav)//2:])[0].cpu())
                            codes = torch.cat(codes, dim=0)
                        else:
                            encoded_frames = model(padded_wav.cuda())
                            codes = encoded_frames[0].cpu()  # [B, n_codebook, T]

                    for i, length in enumerate(all_lens):
                        save_fn = os.path.join(save_codes_dir, segment_id_batch[i] + ".txt")
                        actual_len = round(length / args.downsample_rate)  # 320 is the downsample rate for this model
                        cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
                        os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                        write_array_to_txt_file(cur_code, save_fn)

                        mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n")  # write to manifest file
            except Exception as e:
                print(f'exception!! at {m+1}')
                print(e)
                continue

    logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
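Each saved code file mirrors write_array_to_txt_file: one line per codebook, space-separated integer codes, so a file parses back into an [n_codebook, T] array. A minimal sketch for loading one back (the file path is hypothetical):

# Minimal sketch: load one code file saved above back into an array.
# ndmin=2 keeps the [n_codebook, T] shape even if there is a single codebook.
import numpy as np
codes = np.loadtxt("encodec_4cb/EN_B00008/EN_B00008_S00001_W000000.txt", dtype=np.int64, ndmin=2)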
data/emilia_preprocessing/step6_encodec_encode_script.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate voicestar

dir=${processed_dir:-/data/scratch/pyp/datasets/emilia}
sub_root=${sub_root:-preprocessed}
encodec_name=${encodec_name:-"encodec_6f79c6a8.th"}
n_workers=${n_workers:-64}
batch_size=${batch_size:-512}
audio_sr=16000
model_sr=16000
downsample_rate=320
model_code_sr=50
len_cap=1000
min_len=0.5
partition=${partition:-"1/1"}
split=${split:-"train"}

python step6_encodec_encode.py --root $dir --sub_root ${sub_root} --encodec_name ${encodec_name} --n_workers $n_workers --batch_size $batch_size --audio_sr $audio_sr --model_sr $model_sr --downsample_rate $downsample_rate --model_code_sr $model_code_sr --len_cap $len_cap --min_len $min_len --partition $partition --split $split
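Since every setting above falls back through ${var:-default}, the script is configured entirely via environment variables, e.g. `processed_dir=/path/to/emilia partition="1/2" split=train bash step6_encodec_encode_script.sh` to encode the first of two shards of the train split (values illustrative); running the same command with `partition="2/2"` on a second machine covers the remaining shard.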
data/encodec.py
ADDED
@@ -0,0 +1,1554 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Compression models or wrapper around existing models.
|
7 |
+
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
from dataclasses import dataclass, field
|
12 |
+
import logging
|
13 |
+
import math
|
14 |
+
from pathlib import Path
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from torch import einsum
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import warnings
|
26 |
+
from einops import rearrange, repeat
|
27 |
+
import omegaconf
|
28 |
+
# import flashy
|
29 |
+
|
30 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
31 |
+
'time_group_norm'])
|
32 |
+
|
33 |
+
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
|
34 |
+
"""Convenience function to map an omegaconf configuration to a dictionary.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
cfg (omegaconf.DictConfig): Original configuration to map to dict.
|
38 |
+
Returns:
|
39 |
+
dict: Config as dictionary object.
|
40 |
+
"""
|
41 |
+
dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
|
42 |
+
assert isinstance(dct, dict)
|
43 |
+
return dct
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class QuantizedResult:
|
47 |
+
x: torch.Tensor
|
48 |
+
codes: torch.Tensor
|
49 |
+
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
50 |
+
penalty: tp.Optional[torch.Tensor] = None
|
51 |
+
metrics: dict = field(default_factory=dict)
|
52 |
+
|
53 |
+
class BaseQuantizer(nn.Module):
|
54 |
+
"""Base class for quantizers.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
|
58 |
+
"""
|
59 |
+
Given input tensor x, returns first the quantized (or approximately quantized)
|
60 |
+
representation along with quantized codes, bandwidth, and any penalty term for the loss.
|
61 |
+
Finally, this returns a dict of metrics to update logging etc.
|
62 |
+
Frame rate must be passed so that the bandwidth is properly computed.
|
63 |
+
"""
|
64 |
+
raise NotImplementedError()
|
65 |
+
|
66 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
67 |
+
"""Encode a given input tensor with the specified sample rate at the given bandwidth."""
|
68 |
+
raise NotImplementedError()
|
69 |
+
|
70 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
71 |
+
"""Decode the given codes to the quantized representation."""
|
72 |
+
raise NotImplementedError()
|
73 |
+
|
74 |
+
@property
|
75 |
+
def total_codebooks(self):
|
76 |
+
"""Total number of codebooks."""
|
77 |
+
raise NotImplementedError()
|
78 |
+
|
79 |
+
@property
|
80 |
+
def num_codebooks(self):
|
81 |
+
"""Number of active codebooks."""
|
82 |
+
raise NotImplementedError()
|
83 |
+
|
84 |
+
def set_num_codebooks(self, n: int):
|
85 |
+
"""Set the number of active codebooks."""
|
86 |
+
raise NotImplementedError()
|
87 |
+
|
88 |
+
class CompressionModel(ABC, nn.Module):
|
89 |
+
"""Base API for all compression model that aim at being used as audio tokenizers
|
90 |
+
with a language model.
|
91 |
+
"""
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def forward(self, x: torch.Tensor) -> QuantizedResult:
|
95 |
+
...
|
96 |
+
|
97 |
+
@abstractmethod
|
98 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
99 |
+
"""See `EncodecModel.encode`."""
|
100 |
+
...
|
101 |
+
|
102 |
+
@abstractmethod
|
103 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
104 |
+
"""See `EncodecModel.decode`."""
|
105 |
+
...
|
106 |
+
|
107 |
+
@abstractmethod
|
108 |
+
def decode_latent(self, codes: torch.Tensor):
|
109 |
+
"""Decode from the discrete codes to continuous latent space."""
|
110 |
+
...
|
111 |
+
|
112 |
+
@property
|
113 |
+
@abstractmethod
|
114 |
+
def channels(self) -> int:
|
115 |
+
...
|
116 |
+
|
117 |
+
@property
|
118 |
+
@abstractmethod
|
119 |
+
def frame_rate(self) -> float:
|
120 |
+
...
|
121 |
+
|
122 |
+
@property
|
123 |
+
@abstractmethod
|
124 |
+
def sample_rate(self) -> int:
|
125 |
+
...
|
126 |
+
|
127 |
+
@property
|
128 |
+
@abstractmethod
|
129 |
+
def cardinality(self) -> int:
|
130 |
+
...
|
131 |
+
|
132 |
+
@property
|
133 |
+
@abstractmethod
|
134 |
+
def num_codebooks(self) -> int:
|
135 |
+
...
|
136 |
+
|
137 |
+
@property
|
138 |
+
@abstractmethod
|
139 |
+
def total_codebooks(self) -> int:
|
140 |
+
...
|
141 |
+
|
142 |
+
@abstractmethod
|
143 |
+
def set_num_codebooks(self, n: int):
|
144 |
+
"""Set the active number of codebooks used by the quantizer."""
|
145 |
+
...
|
146 |
+
|
147 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
|
148 |
+
assert norm in CONV_NORMALIZATIONS
|
149 |
+
if norm == 'weight_norm':
|
150 |
+
return weight_norm(module)
|
151 |
+
elif norm == 'spectral_norm':
|
152 |
+
return spectral_norm(module)
|
153 |
+
else:
|
154 |
+
# We already check was in CONV_NORMALIZATION, so any other choice
|
155 |
+
# doesn't need reparametrization.
|
156 |
+
return module
|
157 |
+
|
158 |
+
|
159 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
|
160 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
161 |
+
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
162 |
+
"""
|
163 |
+
assert norm in CONV_NORMALIZATIONS
|
164 |
+
if norm == 'time_group_norm':
|
165 |
+
if causal:
|
166 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
167 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
168 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
169 |
+
else:
|
170 |
+
return nn.Identity()
|
171 |
+
|
172 |
+
|
173 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
174 |
+
padding_total: int = 0) -> int:
|
175 |
+
"""See `pad_for_conv1d`."""
|
176 |
+
length = x.shape[-1]
|
177 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
178 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
179 |
+
return ideal_length - length
|
180 |
+
|
181 |
+
|
182 |
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
183 |
+
"""Pad for a convolution to make sure that the last window is full.
|
184 |
+
Extra padding is added at the end. This is required to ensure that we can rebuild
|
185 |
+
an output of the same length, as otherwise, even with padding, some time steps
|
186 |
+
might get removed.
|
187 |
+
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
188 |
+
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
189 |
+
1 2 3 # (output frames of a convolution, last 0 is never used)
|
190 |
+
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
191 |
+
1 2 3 4 # once you removed padding, we are missing one time step !
|
192 |
+
"""
|
193 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
194 |
+
return F.pad(x, (0, extra_padding))
|
195 |
+
|
196 |
+
|
197 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
|
198 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
199 |
+
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
200 |
+
"""
|
201 |
+
length = x.shape[-1]
|
202 |
+
padding_left, padding_right = paddings
|
203 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
204 |
+
if mode == 'reflect':
|
205 |
+
max_pad = max(padding_left, padding_right)
|
206 |
+
extra_pad = 0
|
207 |
+
if length <= max_pad:
|
208 |
+
extra_pad = max_pad - length + 1
|
209 |
+
x = F.pad(x, (0, extra_pad))
|
210 |
+
padded = F.pad(x, paddings, mode, value)
|
211 |
+
end = padded.shape[-1] - extra_pad
|
212 |
+
return padded[..., :end]
|
213 |
+
else:
|
214 |
+
return F.pad(x, paddings, mode, value)
|
215 |
+
|
216 |
+
|
217 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
218 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
219 |
+
padding_left, padding_right = paddings
|
220 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
221 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
222 |
+
end = x.shape[-1] - padding_right
|
223 |
+
return x[..., padding_left: end]
|
224 |
+
|
225 |
+
|
226 |
+
class NormConv1d(nn.Module):
|
227 |
+
"""Wrapper around Conv1d and normalization applied to this conv
|
228 |
+
to provide a uniform interface across normalization approaches.
|
229 |
+
"""
|
230 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
231 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
232 |
+
super().__init__()
|
233 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
234 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
235 |
+
self.norm_type = norm
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
x = self.conv(x)
|
239 |
+
x = self.norm(x)
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class NormConv2d(nn.Module):
|
244 |
+
"""Wrapper around Conv2d and normalization applied to this conv
|
245 |
+
to provide a uniform interface across normalization approaches.
|
246 |
+
"""
|
247 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
248 |
+
super().__init__()
|
249 |
+
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
250 |
+
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
251 |
+
self.norm_type = norm
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
x = self.conv(x)
|
255 |
+
x = self.norm(x)
|
256 |
+
return x
|
257 |
+
|
258 |
+
|
259 |
+
class NormConvTranspose1d(nn.Module):
|
260 |
+
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
261 |
+
to provide a uniform interface across normalization approaches.
|
262 |
+
"""
|
263 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
264 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
265 |
+
super().__init__()
|
266 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
267 |
+
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
268 |
+
self.norm_type = norm
|
269 |
+
|
270 |
+
def forward(self, x):
|
271 |
+
x = self.convtr(x)
|
272 |
+
x = self.norm(x)
|
273 |
+
return x
|
274 |
+
|
275 |
+
|
276 |
+
class NormConvTranspose2d(nn.Module):
|
277 |
+
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
278 |
+
to provide a uniform interface across normalization approaches.
|
279 |
+
"""
|
280 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
281 |
+
super().__init__()
|
282 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
283 |
+
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
x = self.convtr(x)
|
287 |
+
x = self.norm(x)
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
class StreamableConv1d(nn.Module):
|
292 |
+
"""Conv1d with some builtin handling of asymmetric or causal padding
|
293 |
+
and normalization.
|
294 |
+
"""
|
295 |
+
def __init__(self, in_channels: int, out_channels: int,
|
296 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
297 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
298 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
299 |
+
pad_mode: str = 'reflect'):
|
300 |
+
super().__init__()
|
301 |
+
# warn user on unusual setup between dilation and stride
|
302 |
+
if stride > 1 and dilation > 1:
|
303 |
+
warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
|
304 |
+
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
|
305 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
306 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
307 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
308 |
+
self.causal = causal
|
309 |
+
self.pad_mode = pad_mode
|
310 |
+
|
311 |
+
def forward(self, x):
|
312 |
+
B, C, T = x.shape
|
313 |
+
kernel_size = self.conv.conv.kernel_size[0]
|
314 |
+
stride = self.conv.conv.stride[0]
|
315 |
+
dilation = self.conv.conv.dilation[0]
|
316 |
+
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
317 |
+
padding_total = kernel_size - stride
|
318 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
319 |
+
if self.causal:
|
320 |
+
# Left padding for causal
|
321 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
322 |
+
else:
|
323 |
+
# Asymmetric padding required for odd strides
|
324 |
+
padding_right = padding_total // 2
|
325 |
+
padding_left = padding_total - padding_right
|
326 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
327 |
+
return self.conv(x)
|
328 |
+
|
329 |
+
|
330 |
+
class StreamableConvTranspose1d(nn.Module):
|
331 |
+
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
332 |
+
and normalization.
|
333 |
+
"""
|
334 |
+
def __init__(self, in_channels: int, out_channels: int,
|
335 |
+
kernel_size: int, stride: int = 1, causal: bool = False,
|
336 |
+
norm: str = 'none', trim_right_ratio: float = 1.,
|
337 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}):
|
338 |
+
super().__init__()
|
339 |
+
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
340 |
+
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
|
341 |
+
self.causal = causal
|
342 |
+
self.trim_right_ratio = trim_right_ratio
|
343 |
+
assert self.causal or self.trim_right_ratio == 1., \
|
344 |
+
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
345 |
+
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
346 |
+
|
347 |
+
def forward(self, x):
|
348 |
+
kernel_size = self.convtr.convtr.kernel_size[0]
|
349 |
+
stride = self.convtr.convtr.stride[0]
|
350 |
+
padding_total = kernel_size - stride
|
351 |
+
|
352 |
+
y = self.convtr(x)
|
353 |
+
|
354 |
+
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
355 |
+
# removed at the very end, when keeping only the right length for the output,
|
356 |
+
# as removing it here would require also passing the length at the matching layer
|
357 |
+
# in the encoder.
|
358 |
+
if self.causal:
|
359 |
+
# Trim the padding on the right according to the specified ratio
|
360 |
+
# if trim_right_ratio = 1.0, trim everything from right
|
361 |
+
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
362 |
+
padding_left = padding_total - padding_right
|
363 |
+
y = unpad1d(y, (padding_left, padding_right))
|
364 |
+
else:
|
365 |
+
# Asymmetric padding required for odd strides
|
366 |
+
padding_right = padding_total // 2
|
367 |
+
padding_left = padding_total - padding_right
|
368 |
+
y = unpad1d(y, (padding_left, padding_right))
|
369 |
+
return y
|
370 |
+
|
371 |
+
|
372 |
+
class StreamableLSTM(nn.Module):
|
373 |
+
"""LSTM without worrying about the hidden state, nor the layout of the data.
|
374 |
+
Expects input as convolutional layout.
|
375 |
+
"""
|
376 |
+
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
377 |
+
super().__init__()
|
378 |
+
self.skip = skip
|
379 |
+
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
380 |
+
|
381 |
+
def forward(self, x):
|
382 |
+
x = x.permute(2, 0, 1)
|
383 |
+
y, _ = self.lstm(x)
|
384 |
+
if self.skip:
|
385 |
+
y = y + x
|
386 |
+
y = y.permute(1, 2, 0)
|
387 |
+
return y
|
388 |
+
|
389 |
+
|
390 |
+
class SEANetResnetBlock(nn.Module):
|
391 |
+
"""Residual block from SEANet model.
|
392 |
+
|
393 |
+
Args:
|
394 |
+
dim (int): Dimension of the input/output.
|
395 |
+
kernel_sizes (list): List of kernel sizes for the convolutions.
|
396 |
+
dilations (list): List of dilations for the convolutions.
|
397 |
+
activation (str): Activation function.
|
398 |
+
activation_params (dict): Parameters to provide to the activation function.
|
399 |
+
norm (str): Normalization method.
|
400 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
401 |
+
causal (bool): Whether to use fully causal convolution.
|
402 |
+
pad_mode (str): Padding mode for the convolutions.
|
403 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
404 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
405 |
+
(streamable) convolution as the skip connection.
|
406 |
+
"""
|
407 |
+
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
|
408 |
+
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
409 |
+
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
|
410 |
+
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
|
411 |
+
super().__init__()
|
412 |
+
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
|
413 |
+
act = getattr(nn, activation)
|
414 |
+
hidden = dim // compress
|
415 |
+
block = []
|
416 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
417 |
+
in_chs = dim if i == 0 else hidden
|
418 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
419 |
+
block += [
|
420 |
+
act(**activation_params),
|
421 |
+
StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
|
422 |
+
norm=norm, norm_kwargs=norm_params,
|
423 |
+
causal=causal, pad_mode=pad_mode),
|
424 |
+
]
|
425 |
+
self.block = nn.Sequential(*block)
|
426 |
+
self.shortcut: nn.Module
|
427 |
+
if true_skip:
|
428 |
+
self.shortcut = nn.Identity()
|
429 |
+
else:
|
430 |
+
self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
|
431 |
+
causal=causal, pad_mode=pad_mode)
|
432 |
+
|
433 |
+
def forward(self, x):
|
434 |
+
return self.shortcut(x) + self.block(x)
|
435 |
+
|
436 |
+
|
437 |
+
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

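
# A minimal usage sketch (illustrative only; `_demo_seanet_encoder` is a
# hypothetical helper, not part of the original model code). With the default
# ratios [8, 5, 4, 2], the encoder hop length is 8 * 5 * 4 * 2 = 320 samples,
# so 16 kHz audio yields about 16000 / 320 = 50 latent frames per second; the
# exact frame count depends on the StreamableConv1d padding.
def _demo_seanet_encoder():
    encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32)
    wav = torch.randn(2, 1, 16000)  # [B, C, T]: two mono clips, 1 s at 16 kHz
    emb = encoder(wav)              # roughly [2, 128, 50]
    return encoder.hop_length, emb.shape
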
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (e.g. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y

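
# A hedged round-trip sketch: the decoder mirrors the encoder's ratios, so
# chaining them approximately restores the input length, up to extra padding
# that EncodecModel.forward trims. `_demo_seanet_roundtrip` is illustrative
# only, not part of the original code.
def _demo_seanet_roundtrip():
    encoder = SEANetEncoder(dimension=128)
    decoder = SEANetDecoder(dimension=128)
    wav = torch.randn(1, 1, 16000)
    recon = decoder(encoder(wav))  # untrained weights: audio is noise,
    return wav.shape, recon.shape  # but recon length is >= 16000 before trimming
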
def exists(val: tp.Optional[tp.Any]) -> bool:
    return val is not None


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if exists(val) else d


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins

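
# A small self-contained check of the kmeans helper above (illustrative only;
# `_demo_kmeans` is not part of the original file): cluster 2-D points drawn
# around two well-separated centers. `bins` counts the samples assigned to
# each centroid in the final iteration.
def _demo_kmeans():
    pts = torch.cat([torch.randn(100, 2) - 3.0, torch.randn(100, 2) + 3.0])
    means, bins = kmeans(pts, num_clusters=2, num_iters=10)
    return means, bins  # means near (-3, -3) and (3, 3) in some order, bins ~ [100, 100]
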
def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    n = t.shape[0]
    normed_codes = l2norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)

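
# Quick sanity check (illustrative only, not part of the original file): for
# an orthonormal codebook the cosine-similarity matrix equals the identity,
# so the penalty above is exactly zero.
def _demo_orthogonal_loss():
    assert orthogonal_loss_fn(torch.eye(4)).item() == 0.0  # orthonormal -> no penalty
    noisy = torch.eye(4) + 0.1 * torch.randn(4, 4)
    return orthogonal_loss_fn(noisy)  # small positive penalty
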
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        # Training is not supported in this codebase; the original EMA update
        # below the raise is unreachable and kept for reference only.
        raise NotImplementedError()
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind

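
# The quantize() method above is a nearest-neighbour lookup that expands
# ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2 for efficiency. A hedged equivalence
# check against torch.cdist (illustrative only; indices can differ on exact
# floating-point ties):
def _demo_nearest_neighbour():
    x = torch.randn(8, 4)       # [N, D] input vectors
    embed = torch.randn(16, 4)  # [codebook_size, D] code vectors
    dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed.t()
             + embed.t().pow(2).sum(0, keepdim=True))
    idx_expansion = dist.max(dim=-1).indices
    idx_cdist = torch.cdist(x, embed).argmin(dim=-1)
    return idx_expansion, idx_cdist  # equal up to floating-point ties
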
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size

        self.channels_last = channels_last

    @property
    def codebook(self):
        return self._codebook.embed

    @property
    def inited(self):
        return self._codebook.inited

    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def _postprocess(self, quantize):
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize

    def forward(self, x):
        device = x.device
        x = self._preprocess(x)

        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]

                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)

        return quantize, embed_ind, loss

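
# The line `quantize = x + (quantize - x).detach()` above is the
# straight-through estimator: the forward pass outputs the quantized value
# while the backward pass treats quantization as the identity, so gradients
# flow to x. A minimal self-contained sketch (illustrative only):
def _demo_straight_through():
    x = torch.tensor([0.3, 1.7], requires_grad=True)
    q = torch.round(x)           # non-differentiable stand-in for a codebook lookup
    q_st = x + (q - x).detach()  # forward value: [0., 2.]
    q_st.sum().backward()
    return x.grad                # tensor([1., 1.]): gradient passes straight through
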
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        codebook_size = kwargs.pop('codebook_size', None)
        if codebook_size is None:
            raise ValueError("codebook_size must be provided in kwargs")
        if not isinstance(codebook_size, list):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VectorQuantization(codebook_size=cur_codebook_size, **kwargs)
             for _, cur_codebook_size in zip(range(num_quantizers), codebook_size)]
        )

        # self.layers = nn.ModuleList(
        #     [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        # )

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            # The original EnCodec code used `residual = residual - quantized.detach()`.
            # Since quantize carries the gradient of the residual (see
            # `quantize = x + (quantize - x).detach()` in VectorQuantization.forward),
            # detaching here makes the commitment loss zero for every codebook except
            # the first one, see https://github.com/facebookresearch/encodec/issues/25
            # -- therefore we changed it.
            residual = residual - quantized
            # residual = residual - quantized.detach()
            # Since the commitment loss is averaged, the scale of the loss does not
            # change (not as claimed in the issue above).
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out

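
# Residual quantization in a nutshell (Algorithm 1 of the SoundStream paper
# cited above): each stage quantizes what the previous stages missed, so the
# running sum converges to x. A scalar sketch with round() standing in for a
# learned codebook (illustrative only, not part of the original code):
def _demo_residual_stages():
    x = torch.tensor([3.14159])
    quantized_out, residual = torch.zeros_like(x), x
    for scale in (1.0, 0.1, 0.01):  # progressively finer "codebooks"
        q = torch.round(residual / scale) * scale
        residual = residual - q
        quantized_out = quantized_out + q
    return quantized_out  # ~3.14, residual error below 0.005
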
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: tp.Union[int, tp.List[int]] = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        if isinstance(self.bins, list):
            bins = self.bins
        else:
            bins = [self.bins] * self.n_q
        bw_per_q = [math.log2(b) * frame_rate / 1000 for b in bins]
        bw = torch.tensor(sum(bw_per_q)).to(x)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n

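
# Worked example for the bandwidth computed in forward() above: each codebook
# with `bins` entries costs log2(bins) bits per frame. With the defaults
# bins=1024 and n_q=8 at a hypothetical frame_rate of 50 Hz:
#   8 * log2(1024) * 50 / 1000 = 8 * 10 * 50 / 1000 = 4.0 kbps.
def _demo_rvq_bandwidth(bins: int = 1024, frame_rate: int = 50, n_q: int = 8) -> float:
    bw_per_q = math.log2(bins) * frame_rate / 1000  # kbps for one codebook
    return n_q * bw_per_q                           # 4.0 for the defaults
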
class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Total number of codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")

class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            raise NotImplementedError("model forward and training is not supported.")
        # The full forward pass below is unreachable and kept for reference only.
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes: an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale: a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)

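
# Hedged sketch of the renormalize path in preprocess()/postprocess() above:
# the waveform is divided by its mono RMS volume before encoding and the scale
# is multiplied back after decoding, so the round trip is the identity up to
# float rounding. Illustrative only, not part of the original code:
def _demo_renormalize():
    x = torch.randn(2, 1, 16000)  # [B, C, T]
    mono = x.mean(dim=1, keepdim=True)
    scale = 1e-8 + mono.pow(2).mean(dim=2, keepdim=True).sqrt()
    x_back = (x / scale) * scale
    return torch.allclose(x, x_back)  # True
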
class EncodecModel_encode_only(CompressionModel):
    """Encodec model operating on the raw waveform. Encode only, so no decoder.

    Args:
        encoder (nn.Module): Encoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            raise NotImplementedError("model forward and training is not supported.")
        # The code below is unreachable (and this model has no decoder); kept for reference only.
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes: an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale: a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decoding is not supported for the encode-only model."""
        raise NotImplementedError("Decode is not supported for encode only model")

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space.
        Not supported for the encode-only model."""
        raise NotImplementedError("Decode is not supported for encode only model")

def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
    klass = {
        'no_quant': DummyQuantizer,
        'rvq': ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)


def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    if encoder_name == 'seanet':
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        encoder_override_kwargs = kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = SEANetEncoder(**encoder_kwargs)
        decoder = SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        raise KeyError(f"Unexpected autoencoder {encoder_name}")

def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
    """Instantiate a compression model from a checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load(ckpt_fn, map_location='cpu')
    cfg = state['xp.cfg']
    cfg.device = str(device)
    weights = state['best_state']['model']
    assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
    if encode_only:
        all_keys = list(weights.keys())
        for key in all_keys:
            if key.startswith('decoder'):
                del weights[key]
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel_encode_only(encoder, quantizer,
                                                     frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model

    else:
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel(encoder, decoder, quantizer,
                                         frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model

if __name__ == "__main__":
    import torchaudio
    ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
    audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
    audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_compression_model(ckpt_fn, device=device)

    for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
        audio_in, sr = torchaudio.load(audio_in_fn)
        if sr != model.sample_rate:
            audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
        if audio_in.shape[0] == 2:
            audio_in = audio_in.mean(dim=0, keepdim=True)  # downmix stereo to mono
        audio_in = audio_in.unsqueeze(0)
        audio_in = audio_in.to(torch.float32).to(device)
        codes = model.encode(audio_in)[0]
        audio_out = model.decode(codes)[0].cpu()
        torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
data/ll60k_preprocessing/config.yaml
ADDED
@@ -0,0 +1,75 @@
# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft
# Please don't update this file directly. Instead use distinct configuration files
# to override the below configuration.
defaults:
  - _self_
  - dset: default
  - solver: default

device: cuda
dtype: float32
autocast: false
autocast_dtype: bfloat16
seed: 2036
show: false  # just show the model and its size and exit
continue_from:  # continue from a given sig or path
execute_only:  # can be set to generate/evaluate/valid to run that stage
execute_inplace: false  # don't enforce continue_from to be set
                        # to enable inplace execution of the stage. This assumes
                        # that you know what you are doing and execute the stage
                        # preserving the original xp sig.
benchmark_no_load: false  # if set to true, will repeat the same batch instead of loading them

efficient_attention_backend: torch  # can be torch or xformers.
num_threads: 1  # called with torch.set_num_thread.
mp_start_method: forkserver  # multiprocessing method (spawn, fork or fork_server).


label:  # use this if you want twice the same exp, with a name.

# logging parameters
logging:
  level: INFO
  log_updates: 10
  log_tensorboard: false
  log_wandb: false
tensorboard:
  with_media_logging: false
  name:  # optional name for the experiment
  sub_dir:  # optional sub directory to store tensorboard data
wandb:
  with_media_logging: true
  project:  # project name
  name:  # optional name for the experiment
  group:  # optional group

# SLURM launcher configuration.
slurm:
  gpus: 4  # convenience parameter, number of GPUs to use.
  mem_per_gpu: 40  # in GB, total mem is automatically scaled with `gpus`.
  time: 3600
  constraint:
  partition:
  comment:
  setup: []
  exclude: ''

# dora parameters
dora:
  # Output folder for all artifacts of an experiment.
  dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
  # The following entries will be ignored by dora when computing the unique XP signature.
  # Note that slurm.* and dora.* are automatically ignored.
  exclude: [
    'device', 'wandb.*', 'tensorboard.*', 'logging.*',
    'dataset.num_workers', 'eval.num_workers', 'special.*',
    'metrics.visqol.bin', 'metrics.fad.bin',
    'execute_only', 'execute_best', 'generate.every',
    'optim.eager_sync', 'profiler.*', 'deadlock.*',
    'efficient_attention_backend', 'num_threads', 'mp_start_method',
  ]
  use_rendezvous: false
  # for grids, always run from a clean repo, allowing reliable runs and storing
  # the exact commit. Your repo must be absolutely pristine clean.
  # Local `dora run` are not impacted for easier debugging.
  git_save: true
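# A hedged sketch (not part of the config file) of how an omegaconf/Hydra
# config like this is typically consumed, using only documented omegaconf
# calls; the path below assumes the repository root as working directory:
#
#   import omegaconf
#   cfg = omegaconf.OmegaConf.load("data/ll60k_preprocessing/config.yaml")
#   print(cfg.logging.level)   # "INFO"
#   print(cfg.slurm.gpus)      # 4
#   # dict_from_config(cfg.logging) in encodec.py turns such a node into a plain dict.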
data/ll60k_preprocessing/encodec.py
ADDED
@@ -0,0 +1,1554 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
import math
from pathlib import Path
import typing as tp
import warnings

import numpy as np
import torch
from torch import nn
from torch import einsum
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

from einops import rearrange, repeat
import omegaconf
# import flashy

CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_group_norm'])

def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct


@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class BaseQuantizer(nn.Module):
    """Base class for quantizers."""

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """
        Given input tensor x, returns first the quantized (or approximately quantized)
        representation along with quantized codes, bandwidth, and any penalty term for the loss.
        Finally, this returns a dict of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        raise NotImplementedError()

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        raise NotImplementedError()

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise NotImplementedError()


class CompressionModel(ABC, nn.Module):
    """Base API for all compression models that aim at being used as audio tokenizers
    with a language model.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> QuantizedResult:
        ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`."""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`."""
        ...

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...


def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once padding is removed, one time step is missing!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]

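
# A worked example of the padding arithmetic above (illustrative only; the
# `_example_*` helpers in this file are additions, not part of the original
# module). With kernel_size=8 and stride=4 the fixed padding is 8 - 4 = 4;
# for a length-10 input the frame count (10 - 8 + 4) / 4 + 1 = 2.5 is
# fractional, so `pad_for_conv1d` adds 2 extra samples to make the last
# window full.
def _example_padding_arithmetic():
    x = torch.randn(1, 1, 10)
    extra = get_extra_padding_for_conv1d(x, kernel_size=8, stride=4, padding_total=4)
    assert extra == 2
    padded = pad_for_conv1d(x, kernel_size=8, stride=4, padding_total=4)
    assert padded.shape[-1] == 12
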

class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
                          f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y

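
# Minimal usage sketch for the streamable convolutions (illustrative only,
# not part of the original module). With stride=1 and causal=True all padding
# goes on the left, so the output has the same length as the input and no
# future sample leaks into the current frame.
def _example_streamable_conv():
    conv = StreamableConv1d(1, 32, kernel_size=7, causal=True)
    x = torch.randn(2, 1, 24000)  # [B, C, T], one second at 24 kHz
    assert conv(x).shape == (2, 32, 24000)
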

class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y


class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here,
            which must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

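
# A quick sanity check of the encoder geometry (illustrative only, not part
# of the original module). With the default ratios [8, 5, 4, 2] the hop
# length is their product, 320 samples, so 24 kHz audio is encoded at
# 24000 / 320 = 75 latent frames per second.
def _example_seanet_encoder_shape():
    encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32)
    assert encoder.hop_length == 320
    wav = torch.randn(1, 1, 24000)
    emb = encoder(wav)
    assert emb.shape == (1, 128, 75)
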

class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y


def exists(val: tp.Optional[tp.Any]) -> bool:
    return val is not None


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if exists(val) else d


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    n = t.shape[0]
    normed_codes = l2norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)

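
# Usage sketch for the k-means helper above (illustrative only, not part of
# the original module): cluster 1000 random 16-d vectors into 8 centroids.
# `means` holds the centroids; `bins` holds the per-centroid assignment
# counts from the last iteration.
def _example_kmeans():
    samples = torch.randn(1000, 16)
    means, bins = kmeans(samples, num_clusters=8, num_iters=10)
    assert means.shape == (8, 16) and bins.shape == (8,)
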

class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization.
        # `flashy` is not imported in this inference-only copy (see the commented
        # import at the top of the file), so the distributed sync is disabled:
        # flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # Distributed sync disabled for the same reason as in `init_embed_`:
        # flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        # The EMA training path is disabled in this copy; the code below is
        # kept for reference only.
        raise NotImplementedError()
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind

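
# Encode/decode roundtrip sketch for the codebook (illustrative only, not
# part of the original module). Note that `forward` (the EMA training path)
# raises NotImplementedError in this copy; only `encode`/`decode` are meant
# to be used.
def _example_euclidean_codebook():
    codebook = EuclideanCodebook(dim=16, codebook_size=64)  # uniform init (kmeans_init=False)
    x = torch.randn(10, 16)
    indices = codebook.encode(x)      # nearest-centroid index per vector, shape [10]
    recon = codebook.decode(indices)  # centroid lookup, shape [10, 16]
    assert indices.shape == (10,) and recon.shape == (10, 16)
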

class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weight.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size

        self.channels_last = channels_last

    @property
    def codebook(self):
        return self._codebook.embed

    @property
    def inited(self):
        return self._codebook.inited

    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def _postprocess(self, quantize):
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize

    def forward(self, x):
        device = x.device
        x = self._preprocess(x)

        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]

                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)

        return quantize, embed_ind, loss

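
# Usage sketch for a single VQ layer (illustrative only, not part of the
# original module). Inputs are [B, D, T] by default (channels_last=False).
# The straight-through estimator in `forward` (x + (quantize - x).detach())
# is what lets gradients flow through the discrete bottleneck during training.
def _example_vector_quantization():
    vq = VectorQuantization(dim=16, codebook_size=64)
    x = torch.randn(2, 16, 50)
    codes = vq.encode(x)   # [2, 50] integer indices
    recon = vq.decode(codes)  # [2, 16, 50] quantized vectors
    assert codes.shape == (2, 50) and recon.shape == (2, 16, 50)
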

class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        codebook_size = kwargs.pop('codebook_size', None)
        if codebook_size is None:
            raise ValueError("codebook_size must be provided in kwargs")
        if not isinstance(codebook_size, list):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VectorQuantization(codebook_size=cur_codebook_size, **kwargs)
             for _, cur_codebook_size in zip(range(num_quantizers), codebook_size)]
        )

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            # The original EnCodec code used `residual = residual - quantized.detach()`.
            # Because `quantize` carries the gradient of the residual through the
            # straight-through estimator (`quantize = x + (quantize - x).detach()`),
            # detaching here makes the commitment loss zero for every codebook except
            # the first one; see https://github.com/facebookresearch/encodec/issues/25.
            # We therefore keep the gradient. Since the commitment loss is averaged,
            # this does not change the scale of the loss (contrary to what the issue
            # above suggests).
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out

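
# Residual quantization sketch (illustrative only, not part of the original
# module): each layer quantizes what the previous layers failed to capture,
# so summing the per-layer decodes reconstructs the input progressively
# better as more codebooks are used.
def _example_residual_vector_quantization():
    rvq = ResidualVectorQuantization(num_quantizers=4, dim=16, codebook_size=64)
    x = torch.randn(2, 16, 50)
    codes = rvq.encode(x)   # [4, 2, 50]: one index map per quantizer
    recon = rvq.decode(codes)  # [2, 16, 50]: sum of the per-layer decodes
    assert codes.shape == (4, 2, 50) and recon.shape == (2, 16, 50)
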

class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int or list): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weight.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: tp.Union[int, tp.List[int]] = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        if isinstance(self.bins, list):
            bins = self.bins
        else:
            bins = [self.bins] * self.n_q
        bw_per_q = [math.log2(b) * frame_rate / 1000 for b in bins]
        bw = torch.tensor(sum(bw_per_q)).to(x)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n

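
# The bandwidth arithmetic from `forward` above, worked out (illustrative
# only, not part of the original module): with bins=1024 each code costs
# log2(1024) = 10 bits; at a 75 Hz frame rate one codebook is 750 bit/s, so
# n_q=8 codebooks give 8 * 10 * 75 / 1000 = 6 kbps.
def _example_rvq_bandwidth():
    bits_per_code = math.log2(1024)       # 10 bits per codebook entry
    kbps = 8 * bits_per_code * 75 / 1000  # n_q * bits * frame_rate / 1000
    assert kbps == 6.0
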

class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization."""
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")


class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            # Training-time forward is not supported in this copy; the code
            # below is kept for reference only.
            raise NotImplementedError("model forward and training is not supported.")
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)

1326 |
+
class EncodecModel_encode_only(CompressionModel):
|
1327 |
+
"""Encodec model operating on the raw waveform. Encode only, so no decoder
|
1328 |
+
|
1329 |
+
Args:
|
1330 |
+
encoder (nn.Module): Encoder network.
|
1331 |
+
quantizer (BaseQuantizer): Quantizer network.
|
1332 |
+
frame_rate (int): Frame rate for the latent representation.
|
1333 |
+
sample_rate (int): Audio sample rate.
|
1334 |
+
channels (int): Number of audio channels.
|
1335 |
+
causal (bool): Whether to use a causal version of the model.
|
1336 |
+
renormalize (bool): Whether to renormalize the audio before running the model.
|
1337 |
+
"""
|
1338 |
+
# we need assignment to override the property in the abstract class,
|
1339 |
+
# I couldn't find a better way...
|
1340 |
+
frame_rate: float = 0
|
1341 |
+
sample_rate: int = 0
|
1342 |
+
channels: int = 0
|
1343 |
+
|
1344 |
+
def __init__(self,
|
1345 |
+
encoder: nn.Module,
|
1346 |
+
quantizer: BaseQuantizer,
|
1347 |
+
frame_rate: int,
|
1348 |
+
sample_rate: int,
|
1349 |
+
channels: int,
|
1350 |
+
causal: bool = False,
|
1351 |
+
renormalize: bool = False):
|
1352 |
+
super().__init__()
|
1353 |
+
self.encoder = encoder
|
1354 |
+
self.quantizer = quantizer
|
1355 |
+
self.frame_rate = frame_rate
|
1356 |
+
self.sample_rate = sample_rate
|
1357 |
+
self.channels = channels
|
1358 |
+
self.renormalize = renormalize
|
1359 |
+
self.causal = causal
|
1360 |
+
if self.causal:
|
1361 |
+
# we force disabling here to avoid handling linear overlap of segments
|
1362 |
+
# as supported in original EnCodec codebase.
|
1363 |
+
assert not self.renormalize, 'Causal model does not support renormalize'
|
1364 |
+
|
1365 |
+
@property
|
1366 |
+
def total_codebooks(self):
|
1367 |
+
"""Total number of quantizer codebooks available."""
|
1368 |
+
return self.quantizer.total_codebooks
|
1369 |
+
|
1370 |
+
@property
|
1371 |
+
def num_codebooks(self):
|
1372 |
+
"""Active number of codebooks used by the quantizer."""
|
1373 |
+
return self.quantizer.num_codebooks
|
1374 |
+
|
1375 |
+
def set_num_codebooks(self, n: int):
|
1376 |
+
"""Set the active number of codebooks used by the quantizer."""
|
1377 |
+
self.quantizer.set_num_codebooks(n)
|
1378 |
+
|
1379 |
+
@property
|
1380 |
+
def cardinality(self):
|
1381 |
+
"""Cardinality of each codebook."""
|
1382 |
+
return self.quantizer.bins
|
1383 |
+
|
1384 |
+
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
1385 |
+
scale: tp.Optional[torch.Tensor]
|
1386 |
+
if self.renormalize:
|
1387 |
+
mono = x.mean(dim=1, keepdim=True)
|
1388 |
+
volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
|
1389 |
+
scale = 1e-8 + volume
|
1390 |
+
x = x / scale
|
1391 |
+
scale = scale.view(-1, 1)
|
1392 |
+
else:
|
1393 |
+
scale = None
|
1394 |
+
return x, scale
|
1395 |
+
|
1396 |
+
def postprocess(self,
|
1397 |
+
x: torch.Tensor,
|
1398 |
+
scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
|
1399 |
+
if scale is not None:
|
1400 |
+
assert self.renormalize
|
1401 |
+
x = x * scale.view(-1, 1, 1)
|
1402 |
+
return x
|
1403 |
+
|
1404 |
+
def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
|
1405 |
+
if encode:
|
1406 |
+
return self.encode(x)
|
1407 |
+
else:
|
1408 |
+
raise NotImplementedError("model forward and training is not supported.")
|
1409 |
+
assert x.dim() == 3
|
1410 |
+
length = x.shape[-1]
|
1411 |
+
x, scale = self.preprocess(x)
|
1412 |
+
|
1413 |
+
emb = self.encoder(x)
|
1414 |
+
q_res = self.quantizer(emb, self.frame_rate)
|
1415 |
+
out = self.decoder(q_res.x)
|
1416 |
+
|
1417 |
+
# remove extra padding added by the encoder and decoder
|
1418 |
+
assert out.shape[-1] >= length, (out.shape[-1], length)
|
1419 |
+
out = out[..., :length]
|
1420 |
+
|
1421 |
+
q_res.x = self.postprocess(out, scale)
|
1422 |
+
|
1423 |
+
return q_res
|
1424 |
+
|
1425 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
1426 |
+
"""Encode the given input tensor to quantized representation along with scale parameter.
|
1427 |
+
|
1428 |
+
Args:
|
1429 |
+
x (torch.Tensor): Float tensor of shape [B, C, T]
|
1430 |
+
|
1431 |
+
Returns:
|
1432 |
+
codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
|
1433 |
+
codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
|
1434 |
+
scale a float tensor containing the scale for audio renormalizealization.
|
1435 |
+
"""
|
1436 |
+
assert x.dim() == 3
|
1437 |
+
x, scale = self.preprocess(x)
|
1438 |
+
emb = self.encoder(x)
|
1439 |
+
codes = self.quantizer.encode(emb)
|
1440 |
+
return codes, scale
|
1441 |
+
|
1442 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
1443 |
+
"""Decode the given codes to a reconstructed representation, using the scale to perform
|
1444 |
+
audio denormalization if needed.
|
1445 |
+
|
1446 |
+
Args:
|
1447 |
+
codes (torch.Tensor): Int tensor of shape [B, K, T]
|
1448 |
+
scale (torch.Tensor, optional): Float tensor containing the scale value.
|
1449 |
+
|
1450 |
+
Returns:
|
1451 |
+
out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
|
1452 |
+
"""
|
1453 |
+
raise NotImplementedError("Decode is not supported for encode only model")
|
1454 |
+
emb = self.decode_latent(codes)
|
1455 |
+
out = self.decoder(emb)
|
1456 |
+
out = self.postprocess(out, scale)
|
1457 |
+
# out contains extra padding added by the encoder and decoder
|
1458 |
+
return out
|
1459 |
+
|
1460 |
+
def decode_latent(self, codes: torch.Tensor):
|
1461 |
+
"""Decode from the discrete codes to continuous latent space."""
|
1462 |
+
raise NotImplementedError("Decode is not supported for encode only model")
|
1463 |
+
return self.quantizer.decode(codes)
|
1464 |
+
|
1465 |
+
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
|
1466 |
+
klass = {
|
1467 |
+
'no_quant': DummyQuantizer,
|
1468 |
+
'rvq': ResidualVectorQuantizer
|
1469 |
+
}[quantizer]
|
1470 |
+
kwargs = dict_from_config(getattr(cfg, quantizer))
|
1471 |
+
if quantizer != 'no_quant':
|
1472 |
+
kwargs['dimension'] = dimension
|
1473 |
+
return klass(**kwargs)
|
1474 |
+
|
1475 |
+
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
1476 |
+
if encoder_name == 'seanet':
|
1477 |
+
kwargs = dict_from_config(getattr(cfg, 'seanet'))
|
1478 |
+
encoder_override_kwargs = kwargs.pop('encoder')
|
1479 |
+
decoder_override_kwargs = kwargs.pop('decoder')
|
1480 |
+
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
1481 |
+
decoder_kwargs = {**kwargs, **decoder_override_kwargs}
|
1482 |
+
encoder = SEANetEncoder(**encoder_kwargs)
|
1483 |
+
decoder = SEANetDecoder(**decoder_kwargs)
|
1484 |
+
return encoder, decoder
|
1485 |
+
else:
|
1486 |
+
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
1487 |
+
|
1488 |
+
|
1489 |
+
def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
|
1490 |
+
"""Instantiate a compression model."""
|
1491 |
+
if device == None:
|
1492 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1493 |
+
state = torch.load(ckpt_fn, map_location='cpu')
|
1494 |
+
cfg = state['xp.cfg']
|
1495 |
+
cfg.device = str(device)
|
1496 |
+
weights = state['best_state']['model']
|
1497 |
+
assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
|
1498 |
+
if encode_only:
|
1499 |
+
all_keys = list(weights.keys())
|
1500 |
+
for key in all_keys:
|
1501 |
+
if key.startswith('decoder'):
|
1502 |
+
del weights[key]
|
1503 |
+
kwargs = dict_from_config(getattr(cfg, 'encodec'))
|
1504 |
+
encoder_name = kwargs.pop('autoencoder')
|
1505 |
+
quantizer_name = kwargs.pop('quantizer')
|
1506 |
+
encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
|
1507 |
+
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
1508 |
+
frame_rate = kwargs['sample_rate'] // encoder.hop_length
|
1509 |
+
renormalize = kwargs.pop('renormalize', False)
|
1510 |
+
# deprecated params
|
1511 |
+
kwargs.pop('renorm', None)
|
1512 |
+
compression_model = EncodecModel_encode_only(encoder, quantizer,
|
1513 |
+
frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
|
1514 |
+
assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
|
1515 |
+
compression_model.load_state_dict(weights)
|
1516 |
+
compression_model.eval()
|
1517 |
+
return compression_model
|
1518 |
+
|
1519 |
+
else:
|
1520 |
+
kwargs = dict_from_config(getattr(cfg, 'encodec'))
|
1521 |
+
encoder_name = kwargs.pop('autoencoder')
|
1522 |
+
quantizer_name = kwargs.pop('quantizer')
|
1523 |
+
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
|
1524 |
+
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
1525 |
+
frame_rate = kwargs['sample_rate'] // encoder.hop_length
|
1526 |
+
renormalize = kwargs.pop('renormalize', False)
|
1527 |
+
# deprecated params
|
1528 |
+
kwargs.pop('renorm', None)
|
1529 |
+
compression_model = EncodecModel(encoder, decoder, quantizer,
|
1530 |
+
frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
|
1531 |
+
assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
|
1532 |
+
compression_model.load_state_dict(weights)
|
1533 |
+
compression_model.eval()
|
1534 |
+
return compression_model
|
1535 |
+
|
1536 |
+
if __name__ == "__main__":
|
1537 |
+
import torchaudio
|
1538 |
+
ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
|
1539 |
+
audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"]
|
1540 |
+
audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"]
|
1541 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1542 |
+
model = get_compression_model(ckpt_fn, device=device)
|
1543 |
+
|
1544 |
+
for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
|
1545 |
+
audio_in, sr = torchaudio.load(audio_in_fn)
|
1546 |
+
if sr != model.sample_rate:
|
1547 |
+
audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
|
1548 |
+
if audio_in.shape[0] == 2:
|
1549 |
+
audio_in = audio_in.mean(dim=0, keepdim=True)
|
1550 |
+
audio_in = audio_in.unsqueeze(0)
|
1551 |
+
audio_in = audio_in.to(torch.float32).to(device)
|
1552 |
+
codes = model.encode(audio_in)[0]
|
1553 |
+
audio_out = model.decode(codes)[0].cpu()
|
1554 |
+
torchaudio.save(audio_out_fn, audio_out, model.sample_rate)
|
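
Editor's note: the renormalize path in preprocess/postprocess above is just a per-item RMS round trip. A minimal standalone sketch of the same arithmetic (illustrative only, not part of the file):

    import torch

    x = torch.randn(2, 1, 16000)                                 # [B, C, T]
    mono = x.mean(dim=1, keepdim=True)                           # collapse channels, as in preprocess()
    scale = 1e-8 + mono.pow(2).mean(dim=2, keepdim=True).sqrt()  # per-item RMS volume
    x_norm = x / scale                                           # what the encoder sees
    assert torch.allclose(x_norm * scale, x, atol=1e-6)          # postprocess() restores the original
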
data/ll60k_preprocessing/step1_download.sh
ADDED
@@ -0,0 +1,42 @@
# define where to store the downloaded data
dataroot=$DATAROOT
mkdir -p $dataroot
manifestroot=$dataroot/libriheavy
mkdir -p $manifestroot
audioroot=$dataroot/audio
mkdir -p $audioroot

# download the libriheavy and libriheavy_long manifests
cd $manifestroot
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_dev.jsonl.gz?download=true -O libriheavy_cuts_dev.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_test_clean.jsonl.gz?download=true -O libriheavy_cuts_test_clean.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_test_other.jsonl.gz?download=true -O libriheavy_cuts_test_other.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_small.jsonl.gz?download=true -O libriheavy_cuts_small.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_medium.jsonl.gz?download=true -O libriheavy_cuts_medium.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_large.jsonl.gz?download=true -O libriheavy_cuts_large.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_small.jsonl.gz?download=true -O libriheavy_long_original_cuts_small.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_medium.jsonl.gz?download=true -O libriheavy_long_original_cuts_medium.jsonl.gz
wget https://huggingface.co/datasets/pkufool/libriheavy_long/resolve/main/libriheavy_cuts_large.jsonl.gz?download=true -O libriheavy_long_original_cuts_large.jsonl.gz

# turn .jsonl.gz into .jsonl
gunzip -k libriheavy_cuts_dev.jsonl.gz
gunzip -k libriheavy_cuts_test_clean.jsonl.gz
gunzip -k libriheavy_cuts_test_other.jsonl.gz
gunzip -k libriheavy_cuts_small.jsonl.gz
gunzip -k libriheavy_cuts_medium.jsonl.gz
gunzip -k libriheavy_cuts_large.jsonl.gz
gunzip -k libriheavy_long_original_cuts_small.jsonl.gz
gunzip -k libriheavy_long_original_cuts_medium.jsonl.gz
gunzip -k libriheavy_long_original_cuts_large.jsonl.gz

# if librilight is already unzipped in origDATAROOT, then skip this step
# download the librilight audio
cd $audioroot
wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
wget https://dl.fbaipublicfiles.com/librilight/data/large.tar

# untar small, medium, large
tar -xf small.tar
tar -xf medium.tar
tar -xf large.tar
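
Editor's note: the script reads its target directory from the DATAROOT environment variable, so a typical invocation is DATAROOT=/path/to/librilight bash step1_download.sh (path illustrative). The three LibriLight tarballs are very large (on the order of terabytes), so check disk space first.
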
data/ll60k_preprocessing/step2_resplit_long.py
ADDED
@@ -0,0 +1,51 @@
# find the split, spk, and books in libriheavy_cuts_dev.jsonl, libriheavy_cuts_test_clean.jsonl, libriheavy_cuts_test_other.jsonl
# those are in the "id" field
import os
import json
import tqdm


def write_jsonl(data, fn):
    with open(fn, "w") as file:
        for entry in data:
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_jsonl(file_path):
    cur_data = []
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            cur_data.append(json.loads(line.strip()))
    return cur_data


dataroot = os.environ["DATAROOT"]
manifestroot = os.path.join(dataroot, "libriheavy")
tgt_names = ['libriheavy_cuts_dev.jsonl', 'libriheavy_cuts_test_clean.jsonl', 'libriheavy_cuts_test_other.jsonl']
orig_names = ['libriheavy_long_original_cuts_small.jsonl', 'libriheavy_long_original_cuts_medium.jsonl', 'libriheavy_long_original_cuts_large.jsonl']

# collect the split/speaker/book prefixes of every dev and test cut
data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_dev.jsonl"))
dev_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])
data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_clean.jsonl"))
test_clean_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])
data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_other.jsonl"))
test_other_ids = set(["/".join(item['id'].split("/")[:3]) for item in data])

# route each long cut to dev/test if its prefix appears there, otherwise keep it for training
long_dev = []
long_test_clean = []
long_test_other = []
for orig_name in orig_names:
    keep = []
    data = read_jsonl(os.path.join(manifestroot, orig_name))
    for item in tqdm.tqdm(data):
        if "/".join(item['id'].split("/")[:3]) in dev_ids:
            long_dev.append(item)
        elif "/".join(item['id'].split("/")[:3]) in test_clean_ids:
            long_test_clean.append(item)
        elif "/".join(item['id'].split("/")[:3]) in test_other_ids:
            long_test_other.append(item)
        else:
            keep.append(item)
    write_jsonl(keep, os.path.join(manifestroot, orig_name.replace("_original", "")))

write_jsonl(long_dev, os.path.join(manifestroot, "libriheavy_long_cuts_dev.jsonl"))
write_jsonl(long_test_clean, os.path.join(manifestroot, "libriheavy_long_cuts_test_clean.jsonl"))
write_jsonl(long_test_other, os.path.join(manifestroot, "libriheavy_long_cuts_test_other.jsonl"))
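
Editor's note: the dev/test membership check above keys on the first three path components of each cut's id (size split, speaker, book). A small illustration (the id value is made up, following the recording-id format shown in the step3 comments):

    cut_id = "medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb_0"
    prefix = "/".join(cut_id.split("/")[:3])
    print(prefix)  # -> medium/100/emerald_city_librivox_64kb_mp3
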
data/ll60k_preprocessing/step3_seg_phn_manifest.py
ADDED
@@ -0,0 +1,194 @@
import pathlib
import soundfile as sf
import json
import multiprocessing
import argparse
import tqdm
import gzip
import time
import os
import random
from collections import defaultdict
from tokenizer import TextTokenizer, tokenize_text


def write_jsonl(data, fn):
    with open(fn, "w") as file:
        for entry in data:
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_jsonl(file_path):
    cur_data = []
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            cur_data.append(json.loads(line.strip()))
    return cur_data


def save_audio(seq, fn):
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    sf.write(fn, seq, samplerate=16000)


def save_text(text, fn):
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    with open(fn, "w") as wwf:
        wwf.write(text)


def phonemize_and_save(text, fn):
    phn = tokenize_text(text_tokenizer, text)
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    with open(fn, "w") as f:
        f.write(' '.join(phn))
    return set(phn)


def cut_sequence(task):
    in_audio_fn, output_dir, metadata = task
    if not os.path.isfile(in_audio_fn):
        # print("missing: ", in_audio_fn)
        return None
    data, samplerate = sf.read(in_audio_fn)
    assert len(data.shape) == 1
    assert samplerate == 16000
    all_phns = set()
    for item in metadata:
        out_fn = item['file_id']
        out_audio_fn = os.path.join(output_dir, "audio", out_fn)
        out_text_fn = os.path.join(output_dir, "audio", out_fn.replace(".flac", ".txt"))
        out_phn_fn = os.path.join(output_dir, "phoneme", out_fn.replace(".flac", ".txt"))
        save_audio(data[int(item['vad'][0]*samplerate):int(item['vad'][1]*samplerate)], out_audio_fn)
        save_text(item['text'], out_text_fn)
        phns = phonemize_and_save(item['text'], out_phn_fn)
        all_phns.update(phns)

    return all_phns


# Function to create a defaultdict recursively
def nested_defaultdict(levels, inner_type):
    if levels <= 1:
        return defaultdict(inner_type)
    return defaultdict(lambda: nested_defaultdict(levels-1, inner_type))


def open_mani(fn):
    print("load segmentation and transcription metadata...")
    stime = time.time()
    data = []
    with gzip.open(fn, 'rt', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    print(f"loading done, took {time.time() - stime:.4f} seconds")
    return data


def cut(split,
        audio_dir,
        mani_dir,
        output_dir,
        n_process=32,
        percent=0.5):
    split2manifest = {
        "train": [
            "libriheavy_long_cuts_small.jsonl",
            "libriheavy_long_cuts_medium.jsonl",
            "libriheavy_long_cuts_large.jsonl",
            "libriheavy_cuts_small.jsonl",
            "libriheavy_cuts_medium.jsonl",
            "libriheavy_cuts_large.jsonl",
        ],
        "valid": [
            "libriheavy_cuts_dev.jsonl",
            "libriheavy_long_cuts_dev.jsonl"
        ],
        "test": [
            "libriheavy_cuts_test_clean.jsonl",
            "libriheavy_cuts_test_other.jsonl",
            "libriheavy_long_cuts_test_clean.jsonl",
            "libriheavy_long_cuts_test_other.jsonl"
        ]
    }

    print("organize data by recording_id (i.e. the original big .flac file name)...")
    stime = time.time()
    organized_data = nested_defaultdict(4, list)
    manifest_fn = os.path.join(output_dir, "manifest_mimi", split+".txt")
    os.makedirs(os.path.join(output_dir, "manifest_mimi"), exist_ok=True)
    with open(manifest_fn, "w") as wf:
        for mani_fn in split2manifest[split]:
            # data = open_mani(os.path.join(mani_dir, mani_fn))
            data = read_jsonl(os.path.join(mani_dir, mani_fn))
            for item in data:
                file_id = item['supervisions'][0]['id'] + '.flac'
                recording_id = item['recording']['id'] + '.flac'
                sizeSplit, spk, book, flac = recording_id.split("/")  # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb'
                if os.path.isfile(os.path.join(audio_dir, recording_id)):
                    vad = (item['start'], item['start']+item['duration'])
                    text = item['supervisions'][0]['custom']['texts'][0]
                    file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac"
                    organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad": vad, "text": text})
                    wf.write(f"{file_id}\t{item['duration']}\n")

    #### take only a subset of tasks
    tasks = [(os.path.join(audio_dir, recording_id), output_dir, organized_data[sizeSplit][spk][book][recording_id], spk) for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]]
    ntasks = len(tasks)
    spk2tasks = defaultdict(list)
    for task in tasks:
        spk2tasks[task[3]].append(task)
    # randomly shuffle the task list of each speaker
    for spk in spk2tasks:
        random.shuffle(spk2tasks[spk])
    # take only `percent` of the tasks, uniformly sampled across speakers:
    # randomly pick a speaker, and then randomly pick a task from that speaker
    tasks = []
    while len(tasks) < ntasks * percent:
        spk = random.choice(list(spk2tasks.keys()))
        if len(spk2tasks[spk]) == 0:
            continue
        tasks.append(spk2tasks[spk].pop()[:-1])  # [:-1] drops the spk element; cut_sequence expects a 3-tuple
    print(f"take only {percent*100:.2f}% of the tasks, {len(tasks)} out of {ntasks} tasks")
    #### take only a subset of tasks

    print(f"organizing done, took {time.time() - stime:.4f} seconds")
    print(f"Launching {n_process} processes")
    phn_vocab = set()
    cnt = 0
    with multiprocessing.Pool(processes=n_process) as pool:
        for phns in tqdm.tqdm(pool.imap_unordered(cut_sequence, tasks), total=len(tasks)):
            cnt += 1
            if phns is not None:
                phn_vocab.update(phns)

    # save the phoneme vocabulary
    if split == "train":
        vocab_fn = os.path.join(output_dir, "vocab.txt")
        with open(vocab_fn, "w") as f:
            for i, phn in enumerate(list(phn_vocab)):
                if i < len(list(phn_vocab)) - 1:
                    f.write(f"{str(i)}\t{phn}\n")
                else:
                    f.write(f"{str(i)}\t{phn}")


def parse_args():
    parser = argparse.ArgumentParser(description="Cut a dataset in small "
                                     "sequences using VAD files")
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz")
    parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example",
                        help="Path to the audio directory")
    parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription files' dir, can be downloaded from https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1")
    parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed",
                        help="Path to the output directory")
    parser.add_argument('--n_workers', type=int, default=16,
                        help="Number of parallel worker processes")
    parser.add_argument('--percent', type=float, default=0.5, help="take only this percent of the tasks, randomly sampled from each speaker")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    text_tokenizer = TextTokenizer()
    cut(args.split, args.audio_dir, args.manifest_dir, args.output_dir, args.n_workers, args.percent)
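
Editor's note: an illustrative invocation, assuming the directory layout produced by step1 (paths are not prescribed by the script beyond its defaults): python step3_seg_phn_manifest.py --split train --audio_dir $DATAROOT/audio --manifest_dir $DATAROOT/libriheavy --output_dir $DATAROOT/preprocessed --n_workers 16 --percent 0.5. Note that text_tokenizer is created under the __main__ guard and reaches the pool workers only via process forking, so the script assumes a fork-based multiprocessing start method (the Linux default).
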
data/ll60k_preprocessing/step4_encodec_encode.py
ADDED
@@ -0,0 +1,184 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="encode the librilight dataset using the codec model")
    parser.add_argument('--dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", help="Path to the directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the batch size summed *over all gpus*, so increase it if you are using more gpus")
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number (in seconds)')
    parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number (in seconds)')
    parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return parser.parse_args()


if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    import os, sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm

    args = parse_args()

    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        if len(inds) < 10:
            return inds[::-1]
        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]

    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))

    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            self.split = split
            self.audio_dir = audio_dir
            manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
            cur_sp = int(args.partition.split("/")[0])-1
            total_sp = int(args.partition.split("/")[1])
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]
            self.data = [l for l in self.data if os.path.isfile(os.path.join(self.audio_dir, l[0]))]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_dir, afn)
                audio, sr = torchaudio.load(fn)
                if sr != args.model_sr:
                    audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
                    sr = args.model_sr
                assert sr == args.model_sr, sr
            except Exception as e:
                # logging.info(f"{e}")
                return None, None, None
            assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

        def collate(self, batch):
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] is not None:
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids

    # roots
    sub_root = args.sub_root
    encodec_manifest_dir = os.path.join(args.dir, sub_root, "manifest_mimi")
    audio_dir = os.path.join(args.dir, sub_root, "audio")
    save_manifest_dir = os.path.join(args.dir, sub_root, "manifest_final_encodec")
    if args.encodec_name == "encodec_6f79c6a8.th":
        save_codes_dir = os.path.join(args.dir, sub_root, "encodec_4cb")
    elif args.encodec_name == "encodec_8cb1024_giga.th":
        save_codes_dir = os.path.join(args.dir, sub_root, "encodec_8cb")

    os.makedirs(save_manifest_dir, exist_ok=True)
    os.makedirs(save_codes_dir, exist_ok=True)

    # load the encodec model
    def import_encodec():
        from encodec import get_compression_model
        userdir = os.path.expanduser("~")
        model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
        model = torch.nn.DataParallel(model)
        return model
    model = import_encodec()

    # set up the dataloader
    mega_batch_size = 1024
    batch_size = args.batch_size

    dataset = mydataset(args.split)
    if len(dataset) == 0:
        logging.info(f"no data found for split {args.split} partition {args.partition}")
        sys.exit(0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
    split = args.split

    skip = 0
    logging.info(f"now processing split {split} partition {args.partition}...")
    mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
    logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
    mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
    logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
    with open(mani_fn, "w") as mani_wf:
        # with open(mani_fn, "a") as mani_wf:  # resume from where we failed
        for m, mega_batch in enumerate(tqdm.tqdm(loader)):
            logging.info(f"====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
            try:
                lengths = np.array(mega_batch[1])
                if len(lengths) == 0:  # the loader might not find any audio, because step3 writes the manifest first and may then cut and save only a subset of the audio
                    continue
                sorted_inds = sort_by_audio_len(lengths)
                for j in range(len(sorted_inds))[::-1]:
                    if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap:  # skip samples shorter than min_len sec or longer than len_cap sec
                        skip += 1
                        del sorted_inds[j]

                n_steps = int(np.ceil(len(sorted_inds) / batch_size))
                for n in tqdm.tqdm(range(n_steps), disable=True):
                    inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                    while len(inds_used) < batch_size:
                        inds_used += sorted_inds[:batch_size-len(inds_used)]
                    wav_batch = [mega_batch[0][id] for id in inds_used]
                    all_lens = [mega_batch[1][id] for id in inds_used]
                    segment_id_batch = [mega_batch[2][id] for id in inds_used]
                    padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1)  # [B, T] -> [B, 1, T]
                    # extract discrete codes from EnCodec
                    # (encode=True routes the DataParallel forward() to encode(); without it forward() raises NotImplementedError)
                    with torch.no_grad():
                        if max(all_lens) > 300000 and len(all_lens) > 1:  # if utterances are long, simply pass half of them at a time
                            codes = []
                            inwav = padded_wav.cuda()
                            codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
                            codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
                            codes = torch.cat(codes, dim=0)
                        else:
                            encoded_frames = model(padded_wav.cuda(), encode=True)
                            codes = encoded_frames[0].cpu()  # [B, n_codebook, T]

                    for i, length in enumerate(all_lens):
                        save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
                        actual_len = round(length / args.downsample_rate)  # 320 is the downsample rate for this model
                        cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
                        os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                        write_array_to_txt_file(cur_code, save_fn)

                        mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n")  # write to the manifest file
            except Exception as e:
                print(f'exception!! at {m+1}')
                print(e)
                continue

    logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
data/ll60k_preprocessing/step4_encodec_encode_script.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate voicestar

dir=${dir:-/data/scratch/pyp/datasets/librilight}
sub_root=${sub_root:-preprocessed}
encodec_name=${encodec_name:-encodec_6f79c6a8.th} # or encodec_8cb1024_giga.th
n_workers=${n_workers:-12}
batch_size=${batch_size:-64}
audio_sr=16000
model_sr=16000
downsample_rate=320
model_code_sr=50
len_cap=1000
min_len=0.5
partition=${partition:-"1/1"}
split=${split:-"valid"}

python step4_encodec_encode.py --dir $dir --sub_root ${sub_root} --encodec_name ${encodec_name} --n_workers $n_workers --batch_size $batch_size --audio_sr $audio_sr --model_sr $model_sr --downsample_rate $downsample_rate --model_code_sr $model_code_sr --len_cap $len_cap --min_len $min_len --partition $partition --split $split
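
Editor's note: every ${var:-default} variable in this wrapper can be overridden from the environment, so a sharded run looks like (paths illustrative): split=train partition=2/8 dir=/data/scratch/datasets/librilight bash step4_encodec_encode_script.sh, with one job launched per partition shard. The dir/sub_root pair must point at step3's output, i.e. the directory containing manifest_mimi/ and audio/.
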
data/ll60k_preprocessing/step5_find_nearest_neighbor.py
ADDED
@@ -0,0 +1,157 @@
# for each audio segment, find the non-overlapping neighboring segments
import pathlib
import numpy as np
import json
import multiprocessing
import argparse
import tqdm
import time
import os
from collections import defaultdict


def write_jsonl(data, fn):
    with open(fn, "w") as file:
        for entry in data:
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_jsonl(file_path):
    cur_data = []
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            cur_data.append(json.loads(line.strip()))
    return cur_data


# Function to create a defaultdict recursively
def nested_defaultdict(levels, inner_type):
    if levels <= 1:
        return defaultdict(inner_type)
    return defaultdict(lambda: nested_defaultdict(levels-1, inner_type))


def find_neighbor(args):
    split2manifest = {
        "train": [
            "libriheavy_cuts_small.jsonl",
            "libriheavy_cuts_medium.jsonl",
            "libriheavy_cuts_large.jsonl",
            "libriheavy_long_cuts_small.jsonl",
            "libriheavy_long_cuts_medium.jsonl",
            "libriheavy_long_cuts_large.jsonl"
        ],
        "valid": [
            "libriheavy_cuts_dev.jsonl",
            "libriheavy_long_cuts_dev.jsonl"
        ],
        "test": [
            "libriheavy_cuts_test_clean.jsonl",
            "libriheavy_cuts_test_other.jsonl",
            "libriheavy_long_cuts_test_clean.jsonl",
            "libriheavy_long_cuts_test_other.jsonl"
        ]
    }

    stime = time.time()
    organized_data = nested_defaultdict(4, list)
    for mani_fn in split2manifest[args.split]:
        mani_full_fn = os.path.join(args.manifest_dir, mani_fn)
        data = read_jsonl(mani_full_fn)
        for item in data:
            file_id = item['supervisions'][0]['id'] + '.flac'
            recording_id = item['recording']['id'] + '.flac'
            sizeSplit, spk, book, flac = recording_id.split("/")  # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb'
            if os.path.isfile(os.path.join(args.audio_dir, recording_id)):
                vad = (item['start'], item['start']+item['duration'])
                text = item['supervisions'][0]['custom']['texts'][0]
                file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac"
                organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad": vad, "text": text})

    # process each recording's segment list in parallel with multiprocessing.Pool
    segments = [organized_data[sizeSplit][spk][book][recording_id] for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]]
    # only keep recordings whose first segment audio exists
    print(f"originally total {len(segments)} segments")
    segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]), "audio", seg[0]['file_id']))]
    print(f"after checking existence, total {len(segments)} segments")
    print(f"organizing took {(time.time()-stime)/60:.2f} minutes")
    with multiprocessing.Pool(processes=args.n_workers) as pool:
        for _ in tqdm.tqdm(pool.imap_unordered(find_neighbor_each, segments), total=len(segments)):
            pass


# audio_root = "/data/scratch/pyp/datasets/librilight/preprocessed/audio"
def find_neighbor_each(segments):
    # for one recording_id, find the non-overlapping neighboring segments based on vad
    # only keep segments that have ipa_alignment files
    segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]), "ipa_alignment", seg['file_id'].replace(".flac", ".txt")))]
    if len(segments) <= 1:
        return
    for i in range(len(segments)):
        # for segment i, find the non-overlapping neighboring segments
        write_fn = os.path.join(args.output_dir, f"{segments[i]['file_id'].replace('.flac', '.txt')}")
        neighbors = []
        distance = []
        for j in range(len(segments)):
            # non-overlapping: segment i ends before segment j starts, or starts after segment j ends
            # (the original compared against segments[j]['vad'][0] in both clauses, which also admits overlapping segments)
            if segments[i]['vad'][1] < segments[j]['vad'][0] or segments[i]['vad'][0] > segments[j]['vad'][1]:
                neighbors.append(segments[j])
                distance.append(min(abs(segments[i]['vad'][1] - segments[j]['vad'][0]), abs(segments[i]['vad'][0] - segments[j]['vad'][1])))
        if len(neighbors) == 0:
            continue
        # order neighbors by distance
        index = np.argsort(distance)
        neighbors_distance = [[neighbors[ind], distance[ind]] for ind in index]
        os.makedirs(os.path.dirname(write_fn), exist_ok=True)
        with open(write_fn, "w") as f:
            for neighbor, dist in neighbors_distance:
                f.write(f"{neighbor['file_id'].replace('.flac', '.txt')}\t{dist}\t{neighbor['vad'][1] - neighbor['vad'][0]}\n")  # file_id, distance, duration


def parse_args():
    parser = argparse.ArgumentParser(description="For each audio segment, find its "
                                     "non-overlapping neighboring segments")
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz")
    parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example",
                        help="Path to the audio directory")
    parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription files' dir, can be downloaded from https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1")
    parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed/neighbors",
                        help="Path to the output directory")
    parser.add_argument('--n_workers', type=int, default=16,
                        help="Number of parallel worker processes")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    find_neighbor(args)
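
Editor's note: an illustrative invocation mirrors step3's arguments, e.g. python step5_find_nearest_neighbor.py --split train --audio_dir $DATAROOT/audio --manifest_dir $DATAROOT/libriheavy --output_dir $DATAROOT/preprocessed/neighbors. One ordering caveat visible in the code: find_neighbor_each keeps only segments that already have files under ipa_alignment/, so despite the step numbering this script is only useful after the alignment steps (step6 and step7) have produced those files.
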
data/ll60k_preprocessing/step6_forced_alignment.py
ADDED
@@ -0,0 +1,86 @@
import os
import subprocess
import tqdm


def align_folders(audio_root, subfolder, subsubfolder):
    # construct the output folder path
    file_root = os.path.dirname(audio_root)
    out_folder = f"{file_root}/alignment/{subfolder}/{subsubfolder}"

    # create the output directory
    os.makedirs(out_folder, exist_ok=True)

    # construct the MFA align command
    command = [
        "mfa", "align", "--single_speaker", "-j", "8", "--clean",
        f"{audio_root}/{subfolder}/{subsubfolder}", "english_us_arpa", "english_us_arpa",
        out_folder, "--beam", "50", "--retry_beam", "400", "--output_format", "csv"
    ]

    # run the command
    subprocess.run(command, check=True)


def main(file_root="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", max_parallel_jobs=10, max_spk=100, partition="1/10", n_workers=64):
    # find all subfolder/subsubfolder combinations
    tasks = []
    audio_root = os.path.join(file_root, "audio")
    for subfolder in os.listdir(audio_root):
        subfolder_path = os.path.join(audio_root, subfolder)
        if os.path.isdir(subfolder_path):
            for subsubfolder in os.listdir(subfolder_path):
                subsubfolder_path = os.path.join(subfolder_path, subsubfolder)
                if os.path.isdir(subsubfolder_path):
                    tasks.append((audio_root, subfolder, subsubfolder))
    # group speaker folders by size split, then chunk each group into batches of at most max_spk speakers
    speaker_folder_map = {}
    for audio_root, subfolder, subsubfolder in tasks:
        if os.path.join(audio_root, subfolder) not in speaker_folder_map:
            speaker_folder_map[os.path.join(audio_root, subfolder)] = [os.path.join(audio_root, subfolder, subsubfolder)]
        else:
            speaker_folder_map[os.path.join(audio_root, subfolder)].append(os.path.join(audio_root, subfolder, subsubfolder))
    speaker_folder_partitions = []
    for audio_root_subfolder, speaker_folders in speaker_folder_map.items():
        speaker_folder_partitions.extend([speaker_folders[i:i+max_spk] for i in range(0, len(speaker_folders), max_spk)])
    s, e = partition.split("/")
    s, e = int(s)-1, int(e)
    cur_tasks = speaker_folder_partitions[s::e]

    import secrets, string
    import soundfile, glob
    from joblib import Parallel, delayed

    def delete_corrupted(fn):
        try:
            soundfile.read(fn)
        except Exception:
            print(f"removing corrupted file: {fn}")
            os.remove(fn)

    for j, task in enumerate(tqdm.tqdm(cur_tasks)):
        # get the size-split subfolder for the current task
        subs = [item.split("/")[-2] for item in task]
        # assert that all subs are the same
        assert len(set(subs)) == 1, subs
        sub = subs[0]
        # randomly generate a 10-character folder name and softlink each speaker folder
        # in the task into that temp folder
        random_string = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(10))
        temp_folder = os.path.join(file_root, "softlink_audio", random_string)
        os.makedirs(temp_folder, exist_ok=True)
        out_folder = f"{file_root}/alignment/{sub}"
        all_out_speaker_folders = [os.path.join(out_folder, os.path.basename(item)) for item in task]
        # skip the task if all of its output folders already exist
        if sum(os.path.isdir(curpath) for curpath in all_out_speaker_folders) == len(all_out_speaker_folders):
            continue
        # remove audio files that are corrupted
        all_audio_files = [audiofile for item in task for audiofile in glob.glob(item+"/*/*.flac")]
        Parallel(n_jobs=n_workers)(delayed(delete_corrupted)(audiofn) for audiofn in all_audio_files)
        for item in task:
            # softlink the speaker folder into the temp folder
            os.symlink(item, os.path.join(temp_folder, os.path.basename(item)))
        # run mfa on the linked folder, but save the alignment to the correct folder
        command = f"mfa align -j {n_workers} {temp_folder} english_us_arpa english_us_arpa {out_folder} --beam 50 --retry_beam 200 --output_format csv --quiet --use_mp --temporary_directory {temp_folder}_temp"
        os.system(command)
        # delete the temp folder
        os.system(f"rm -r {temp_folder}")


if __name__ == "__main__":
    import fire
    fire.Fire(main)
data/ll60k_preprocessing/step6_forced_alignment.sh
ADDED
@@ -0,0 +1,13 @@
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate voicecraft

partition=$1
file_root=$2
max_spk=${max_spk:-100}
n_workers=${n_workers:-64}
python step6_forced_alignment.py \
    --partition $partition \
    --file_root $file_root \
    --max_spk $max_spk \
    --n_workers $n_workers
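
Editor's note: the two positional arguments are the partition shard and the preprocessed root, e.g. bash step6_forced_alignment.sh 1/10 $DATAROOT/preprocessed (path illustrative). This assumes the Montreal Forced Aligner (the mfa command) plus its english_us_arpa acoustic model and dictionary are installed in the activated conda environment.
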
data/ll60k_preprocessing/step7_ipa_alignment.py
ADDED
@@ -0,0 +1,114 @@
# we have raw transcripts at
# /data/scratch/pyp/datasets/librilight/preprocessed/audio
# and word and ARPA alignments at
# /data/scratch/pyp/datasets/librilight/preprocessed/alignment

# we have manifests at /data/scratch/pyp/datasets/librilight/preprocessed/manifest_mimi
# where each row looks like: large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5_610.32_630.08.flac 19.76

# we want to create IPA alignments from the raw transcript and the word alignment, using phonemizer,
# and save them at /data/scratch/pyp/datasets/librilight/preprocessed/ipa_alignment

# since IPA-phonemized results are not 1-to-1 with words (10 words might lead to an IPA sequence of 7 phonemes),
# we have to run the phonemizer on each prefix of the word sequence
import os, string, csv, random, tqdm, glob
from tokenizer import TextTokenizer, tokenize_text


def remove_punctuation(input_string):
    translator = str.maketrans('', '', string.punctuation)
    return input_string.translate(translator)


def create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, save=False, prompt_dur=30):
    os.makedirs(os.path.dirname(ipa_alignment_fn), exist_ok=True)
    trans_fn = os.path.join(trans_dir, fn.replace(audio_ext, trans_ext))
    if not os.path.isfile(trans_fn):
        return [], True
    align_fn = os.path.join(align_dir, fn.replace(audio_ext, arpa_ext))
    if not os.path.isfile(align_fn):
        return [], True
    # get the raw transcript
    with open(trans_fn, 'r') as f:
        transcript = f.read().strip()
    raw_word_list = transcript.split(" ")
    # get the word alignment
    with open(align_fn, 'r') as f:
        word_alignment = csv.reader(f)
        word_alignment = [row for row in word_alignment if row[3] == 'words']

    ipa_alignment = []

    for j, (item, raw_word) in enumerate(zip(word_alignment, raw_word_list)):
        start, end, word = float(item[0]), float(item[1]), item[2]
        if end > prompt_dur:
            break
        punc_re_raw_word = remove_punctuation(raw_word)
        if not remove_punctuation(word).lower() == punc_re_raw_word.lower():
            # print(f"word from alignment csv: {word}, word from txt: {raw_word}")
            return ipa_alignment, True
        if random.random() < use_prob:
            cur_words = " ".join(raw_word_list[:j+1])
            phn = tokenize_text(text_tokenizer, cur_words)
            if len(phn) == 0:
                continue
            phn = " ".join(phn)
            start = 0  # at this point, we always start from the beginning of the sentence
            ipa_alignment.append([start, end, phn])
    if save:
        # write the alignment to disk (nothing is returned in this mode)
        if ipa_alignment:
            with open(ipa_alignment_fn, 'w') as f:
                for item in ipa_alignment:
                    f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n")
    else:
        return ipa_alignment, False


def main(
    data_root: str = '/data/scratch/pyp/datasets/librilight/preprocessed',
    audio_ext: str = '.flac',
    arpa_ext: str = '.csv',
    trans_ext: str = '.txt',
    split: str = 'valid',
    use_prob: float = 0.5,
    max_dur: float = 30.,  # do not consider utterances longer than this
    prompt_dur: float = 30.,  # do not consider prompts longer than this
):
    text_tokenizer = TextTokenizer()
    trans_dir = f'{data_root}/audio'
    align_dir = f'{data_root}/alignment'
    manifest_fn = f"{data_root}/manifest_final_encodec/{split}*=*.txt"
    manifest_fns = glob.glob(manifest_fn)
    target_dir = f'{data_root}/ipa_alignment'
    encodec_sr = 50
    os.makedirs(target_dir, exist_ok=True)
    manifest = []
    for manifest_fn in manifest_fns:
        with open(manifest_fn, 'r') as f:
            temp = [l.strip().split("\t") for l in f.readlines()]
        manifest += [l[0] + audio_ext for l in temp if float(l[1])/encodec_sr < max_dur]
|
90 |
+
manifest += [l[0] + audio_ext for l in temp if float(l[1])/encodec_sr < max_dur]
|
91 |
+
# # sequential processing
|
92 |
+
n_flags = 0
|
93 |
+
zero_words = 0
|
94 |
+
for j, fn in enumerate(tqdm.tqdm(manifest)):
|
95 |
+
ipa_alignment_fn = os.path.join(target_dir, fn.replace(audio_ext, '.txt'))
|
96 |
+
ipa_alignment, flag = create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, prompt_dur=prompt_dur)
|
97 |
+
n_flags += flag
|
98 |
+
if not ipa_alignment:
|
99 |
+
zero_words += 1
|
100 |
+
# print(f"{n_flags} out of {j+1} utterances have mismatched words")
|
101 |
+
# print(f"{zero_words} out of {j+1} utterances have zero words")
|
102 |
+
if ipa_alignment:
|
103 |
+
with open(ipa_alignment_fn, 'w') as f:
|
104 |
+
for item in ipa_alignment:
|
105 |
+
f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n")
|
106 |
+
|
107 |
+
# # # # do the above using joblib parallisim
|
108 |
+
# print(f"Processing {len(manifest)} utterances")
|
109 |
+
# from joblib import Parallel, delayed
|
110 |
+
# Parallel(n_jobs=32, verbose=2)(delayed(create_alignment)(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, os.path.join(target_dir, fn.replace(audio_ext, '.txt')), save=True) for fn in manifest)
|
111 |
+
|
112 |
+
if __name__ == "__main__":
|
113 |
+
import fire
|
114 |
+
fire.Fire(main)
|
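For reference, each file written to `ipa_alignment` holds one candidate prompt per line as `start<TAB>end<TAB>phonemes`, with `start` always 0 (see `create_alignment` above). A minimal sketch of reading one back (the file path here is hypothetical):

# read one IPA alignment file produced by step7_ipa_alignment.py
with open("/data/scratch/pyp/datasets/librilight/preprocessed/ipa_alignment/example.txt") as f:
    for line in f:
        start, end, phn = line.rstrip("\n").split("\t")
        phones = phn.split(" ")  # space-joined phoneme symbols
        print(float(start), float(end), len(phones))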
data/ll60k_preprocessing/tokenizer.py
ADDED
@@ -0,0 +1,460 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
# from encodec import EncodecModel
# from encodec.utils import convert_audio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    pass


class PypinyinBackend:
    """PypinyinBackend for Chinese. Most code is referenced from espnet.
    There are two types, pinyin or initials_finals: one is
    just like "ni1 hao3", the other is like "n i1 h ao3".
    """

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)
            phones = []
            if self.backend == "pypinyin":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)

                        phones.extend(list(py[0]))
                    else:
                        phones.extend([py[0], separator.syllable])
            elif self.backend == "pypinyin_initials_finals":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        if py[0][-1].isalnum():
                            initial = get_initials(py[0], strict=False)
                            if py[0][-1].isdigit():
                                final = (
                                    get_finals(py[0][:-1], strict=False)
                                    + py[0][-1]
                                )
                            else:
                                final = get_finals(py[0], strict=False)
                            phones.extend(
                                [
                                    initial,
                                    separator.phone,
                                    final,
                                    separator.syllable,
                                ]
                            )
                        else:
                            raise ValueError  # was `assert ValueError`, which never raises
            else:
                raise NotImplementedError
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ["pypinyin", "pypinyin_initials_finals"]:
            phonemizer = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols


def remove_encodec_weight_norm(model):
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm
    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)
    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


# class AudioTokenizer:
#     """EnCodec audio."""

#     def __init__(
#         self,
#         bandwidth: float = 6.0,
#         device: Any = None,
#     ) -> None:
#         # Instantiate a pretrained EnCodec model
#         model = EncodecModel.encodec_model_24khz()
#         model.set_target_bandwidth(bandwidth=bandwidth)
#         remove_encodec_weight_norm(model)

#         if not device:
#             device = torch.device("cpu")
#             if torch.cuda.is_available():
#                 device = torch.device("cuda:0")

#         self._device = device

#         self.codec = model.to(device)
#         self.sample_rate = model.sample_rate
#         self.channels = model.channels

#     @property
#     def device(self):
#         return self._device

#     def encode(self, wav: torch.Tensor) -> torch.Tensor:
#         return self.codec.encode(wav.to(self.device))

#     def decode(self, frames: torch.Tensor) -> torch.Tensor:
#         return self.codec.decode(frames)

# class AudioTokenizer:
#     """EnCodec audio."""

#     def __init__(
#         self,
#         bandwidth: float = 6.0,
#         device: Any = None,
#         hificodec=False,
#         signature=None
#     ) -> None:
#         self.hificodec = hificodec
#         self.customized = True if signature != None else False
#         if hificodec:
#             import sys
#             sys.path.append("/home/pyp/AcademiCodec")
#             from academicodec.models.hificodec.vqvae import VQVAE
#             config_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/config_16k_320d.json"
#             model_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/checkpoint/HiFi-Codec-16k-320d"
#             self.sample_rate = 16000
#             self.channels = 1
#             model = VQVAE(config_path, model_path, with_encoder=True)
#             model.generator.remove_weight_norm()
#             model.encoder.remove_weight_norm()
#             model.eval()
#         else:
#             if signature != None:
#                 # use customized encodec model
#                 # import sys
#                 # sys.path.append("home/pyp/audiocraft")
#                 from audiocraft.solvers import CompressionSolver
#                 model_path = f'//sig/{signature}'
#                 model = CompressionSolver.model_from_checkpoint(model_path)
#                 self.sample_rate = model.sample_rate
#                 self.channels = model.channels
#             else:
#                 # Instantiate a pretrained EnCodec model
#                 model = EncodecModel.encodec_model_24khz()
#                 model.set_target_bandwidth(bandwidth=bandwidth)
#                 remove_encodec_weight_norm(model)
#                 self.sample_rate = model.sample_rate
#                 self.channels = model.channels

#         if not device:
#             device = torch.device("cpu")
#             if torch.cuda.is_available():
#                 device = torch.device("cuda:0")

#         self._device = device

#         self.codec = model.to(device)

#     @property
#     def device(self):
#         return self._device

#     def encode(self, wav: torch.Tensor) -> torch.Tensor:
#         if self.hificodec:
#             assert wav.ndim == 3 and wav.shape[:2] == torch.Size((1, 1)), wav.shape
#             wav = wav.squeeze(0)
#             codes = self.codec.encode(wav.to(self.device))  # [1,T,4]
#             return [(codes.transpose(2, 1), None)]
#         elif self.customized:
#             codes = self.codec.encode(wav.to(self.device))
#             return [(codes[0], None)]
#         return self.codec.encode(wav.to(self.device))

#     def decode(self, frames: torch.Tensor) -> torch.Tensor:
#         if self.hificodec:
#             frames = frames[0][0]  # [1,4,T]
#             assert frames.shape[:2] == torch.Size((1, 4))
#             audio = self.codec(frames.transpose(2, 1))
#             assert audio.shape[0] == 1, audio.shape
#             return audio
#         elif self.customized:
#             frames = frames[0][0]  # [1,4,T]
#             return self.codec.decode(frames)
#         return self.codec.decode(frames)
#         # try:
#         #     return self.codec.decode(frames)
#         # except:
#         #     import logging
#         #     logging.info(f"error when decoding frame of shape: {frames[0][0].shape}")
#         #     self.codec.cpu()
#         #     ret = self.codec.cpu().decode([(frames[0][0].cpu(), None)])[0].to(self._device)
#         #     self.codec.to(self._device)
#         #     return [ret]

# def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
#     # Load and pre-process the audio waveform
#     if offset != -1 and num_frames != -1:
#         wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
#     else:
#         wav, sr = torchaudio.load(audio_path)
#     wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
#     wav = wav.unsqueeze(0)

#     # Extract discrete codes from EnCodec
#     with torch.no_grad():
#         encoded_frames = tokenizer.encode(wav)
#     return encoded_frames


# @dataclass
# class AudioTokenConfig:
#     frame_shift: Seconds = 320.0 / 24000
#     num_quantizers: int = 8

#     def to_dict(self) -> Dict[str, Any]:
#         return asdict(self)

#     @staticmethod
#     def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
#         return AudioTokenConfig(**data)


# class AudioTokenExtractor(FeatureExtractor):
#     name = "encodec"
#     config_type = AudioTokenConfig

#     def __init__(self, config: Optional[Any] = None):
#         super(AudioTokenExtractor, self).__init__(config)
#         self.tokenizer = AudioTokenizer()

#     def extract(
#         self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
#     ) -> np.ndarray:
#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = convert_audio(
#                 samples,
#                 sampling_rate,
#                 self.tokenizer.sample_rate,
#                 self.tokenizer.channels,
#             )
#         if len(samples.shape) == 2:
#             samples = samples.unsqueeze(0)
#         else:
#             raise ValueError()

#         device = self.tokenizer.device
#         encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         codes = encoded_frames[0][0]  # [B, n_q, T]
#         if True:
#             duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             assert abs(codes.shape[-1] - expected_num_frames) <= 1
#             codes = codes[..., :expected_num_frames]
#         return codes.cpu().squeeze(0).permute(1, 0).numpy()

#     @property
#     def frame_shift(self) -> Seconds:
#         return self.config.frame_shift

#     def feature_dim(self, sampling_rate: int) -> int:
#         return self.config.num_quantizers

#     def pad_tensor_list(self, tensor_list, device, padding_value=0):
#         # compute the length of each tensor
#         lengths = [tensor.shape[0] for tensor in tensor_list]
#         # pad with pad_sequence
#         tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
#         padded_tensor = torch.nn.utils.rnn.pad_sequence(
#             tensor_list, batch_first=True, padding_value=padding_value
#         )
#         return padded_tensor, lengths

#     def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
#         samples = [wav.squeeze() for wav in samples]
#         device = self.tokenizer.device
#         samples, lengths = self.pad_tensor_list(samples, device)
#         samples = samples.unsqueeze(1)

#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if len(samples.shape) != 3:
#             raise ValueError()
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = [
#                 convert_audio(
#                     wav,
#                     sampling_rate,
#                     self.tokenizer.sample_rate,
#                     self.tokenizer.channels,
#                 )
#                 for wav in samples
#             ]
#         # Extract discrete codes from EnCodec
#         with torch.no_grad():
#             encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         encoded_frames = encoded_frames[0][0]  # [B, n_q, T]
#         batch_codes = []
#         for b, length in enumerate(lengths):
#             codes = encoded_frames[b]
#             duration = round(length / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             batch_codes.append(codes[..., :expected_num_frames])
#         return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]


if __name__ == "__main__":
    from encodec import EncodecModel  # imported here since the top-level import is commented out
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    # model.cuda()
    samples = torch.from_numpy(np.random.random([4, 1, 30000])).type(torch.float32)
    codes_norm = model.encode(samples)  # keep everything on CPU, matching the commented-out model.cuda()
    remove_encodec_weight_norm(model)
    codes_raw = model.encode(samples)

    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
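A minimal usage sketch for the `TextTokenizer` defined above (espeak backend; the exact phoneme symbols depend on the installed espeak version, so the output shown is only illustrative):

from tokenizer import TextTokenizer, tokenize_text

text_tokenizer = TextTokenizer()  # defaults: en-us, espeak, word separator "_"
phones = tokenize_text(text_tokenizer, "hello world")
# a flat list of phoneme symbols with "_" marking the word boundary, e.g.
# ['h', 'ə', 'l', 'oʊ', '_', 'w', 'ɜː', 'l', 'd']
print(phones)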
data/tokenizer.py
ADDED
@@ -0,0 +1,295 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
# from encodec import EncodecModel
# from encodec.utils import convert_audio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    pass


class PypinyinBackend:
    """PypinyinBackend for Chinese. Most code is referenced from espnet.
    There are two types, pinyin or initials_finals: one is
    just like "ni1 hao3", the other is like "n i1 h ao3".
    """

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)
            phones = []
            if self.backend == "pypinyin":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)

                        phones.extend(list(py[0]))
                    else:
                        phones.extend([py[0], separator.syllable])
            elif self.backend == "pypinyin_initials_finals":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        if py[0][-1].isalnum():
                            initial = get_initials(py[0], strict=False)
                            if py[0][-1].isdigit():
                                final = (
                                    get_finals(py[0][:-1], strict=False)
                                    + py[0][-1]
                                )
                            else:
                                final = get_finals(py[0], strict=False)
                            phones.extend(
                                [
                                    initial,
                                    separator.phone,
                                    final,
                                    separator.syllable,
                                ]
                            )
                        else:
                            raise ValueError  # was `assert ValueError`, which never raises
            else:
                raise NotImplementedError
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ["pypinyin", "pypinyin_initials_finals"]:
            phonemizer = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols


def remove_encodec_weight_norm(model):
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm
    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)
    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """mimi audio."""

    def __init__(
        self,
        bandwidth: float = 6.0,
        device: Any = None,
        hificodec=False,
        signature=None,
        encode_only=False
    ) -> None:
        self.signature = signature
        from data.encodec import get_compression_model
        model = get_compression_model(signature, encode_only=encode_only, device=device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        if self.signature is not None:
            if self.signature == "lfsc":
                if wav.ndim == 3:
                    assert wav.shape[:2] == torch.Size((1, 1)), wav.shape
                    wav = wav.squeeze(0)
                elif wav.ndim == 2:
                    assert wav.shape[0] == 1, wav.shape
                else:
                    raise ValueError(wav.shape)
                audio_len = torch.tensor([wav.shape[1]]).to(self.device)
                codes, encoded_len = self.codec.encode(audio=wav.to(self.device), audio_len=audio_len)
                return codes[:, :, :encoded_len[0]]
            else:
                codes = self.codec.encode(wav.to(self.device))
                return codes[0]
        else:
            assert wav.ndim == 3 and wav.shape[:2] == torch.Size((1, 1)), wav.shape
            return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        if self.signature is not None and self.signature == "lfsc":
            encoded_len = torch.tensor([frames.shape[-1]]).to(self.device)
            reconstructed_audio, decoded_len = self.codec.decode(tokens=frames, tokens_len=encoded_len)
            return reconstructed_audio[:, :decoded_len[0]].unsqueeze(0)
        else:
            return self.codec.decode(frames)

def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    # Load and pre-process the audio waveform
    if offset != -1 and num_frames != -1:
        wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
    else:
        wav, sr = torchaudio.load(audio_path)
    if sr != tokenizer.sample_rate:
        wav = torchaudio.transforms.Resample(sr, tokenizer.sample_rate)(wav)
        sr = tokenizer.sample_rate
    if wav.shape[0] == 2:
        wav = wav.mean(dim=0, keepdim=True)
    wav = wav.unsqueeze(0)
    # Extract discrete codes from mimi
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames


if __name__ == "__main__":
    # tok = AudioTokenizer(signature="lfsc", device="cpu")
    tok = AudioTokenizer(signature="/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th", device="cpu")
    inaudio = "/home/pyp/BoostedVoiceEditor/demo/pam.wav"
    encoded_frames = tokenize_audio(tok, inaudio)
    print(encoded_frames.shape)
    # decode it back
    decoded_audio = tok.decode(encoded_frames)
    torchaudio.save("/home/pyp/BoostedVoiceEditor/demo/pam_reconstructed_encodec_4cb_2nd.wav", decoded_audio[0], tok.sample_rate)
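A minimal encode/decode sketch for the `AudioTokenizer` above, mirroring its `__main__` block but with a synthetic waveform; the checkpoint path is the repo's 4-codebook EnCodec signature, and the exact code shapes depend on the underlying compression model:

import torch
from data.tokenizer import AudioTokenizer

tok = AudioTokenizer(signature="./pretrained/encodec_6f79c6a8.th", device="cpu")
wav = torch.zeros(1, 1, tok.sample_rate)  # one second of silence, shape [1, 1, T]
codes = tok.encode(wav)    # discrete codes, one row per codebook
audio = tok.decode(codes)  # waveform reconstructed from the codes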
demo/5895_34622_000026_000002.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6914b36cccc39da83dc9e2f9370556e62876b6f0e37e0ab324d5c85208107fe4
size 503738
generated_tts/generated.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b054b4bcd8b52504ef903b578c4040edbfd4949d3aaae6be7f76ac2976a1f280
size 251598
inference_commandline.py
ADDED
@@ -0,0 +1,192 @@
import os
import torch
import torchaudio
import numpy as np
import random
import whisper
import fire
from argparse import Namespace

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)

from models import voice_star
from inference_tts_utils import inference_one_sample

############################################################
# Utility Functions
############################################################

def seed_everything(seed=1):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def estimate_duration(ref_audio_path, text):
    """
    Estimate duration based on seconds per character from the reference audio.
    Note: since spc is computed from the same text, for non-empty text this
    simplifies to the reference audio's duration.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    length_text = max(len(text), 1)
    spc = audio_duration / length_text  # seconds per character
    return len(text) * spc

############################################################
# Main Inference Function
############################################################

def run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.",
    # Model
    model_name="VoiceStar_840M_30s",  # or VoiceStar_840M_40s; the latter model is trained on speech up to 40s long
    model_root="./pretrained",
    # Additional optional
    reference_text=None,  # if None => run whisper on reference_speech
    target_duration=None,  # if None => estimate from reference_speech and target_text
    # Default hyperparameters from snippet
    codec_audio_sr=16000,  # do not change
    codec_sr=50,  # do not change
    top_k=10,  # try 10, 20, 30, 40
    top_p=1,  # do not change
    min_p=1,  # do not change
    temperature=1,
    silence_tokens=None,  # do not change it
    kvcache=1,  # if OOM, set to 0
    multi_trial=None,  # do not change it
    repeat_prompt=1,  # increase this to improve speaker similarity, but if the total reference speech duration plus the target duration exceeds the maximal training duration, quality may drop
    stop_repetition=3,  # will not use it
    sample_batch_size=1,  # do not change
    # Others
    seed=1,
    output_dir="./generated_tts",
    # Some snippet-based defaults
    cut_off_sec=100,  # do not adjust this; we always use the entire reference speech. If you change it, also change reference_text so that it is only the transcript of the remaining speech
):
    """
    Inference script using Fire.

    Example:
    python inference_commandline.py \
        --reference_speech "./demo/5895_34622_000026_000002.wav" \
        --target_text "I cannot believe ... this audio is 10 seconds long." \
        --reference_text "(optional) text to use as prefix" \
        --target_duration (optional float)
    """

    # Seed everything
    seed_everything(seed)

    # Load model, phn2num, and args
    torch.serialization.add_safe_globals([Namespace])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_fn = os.path.join(model_root, model_name + ".pth")
    if not os.path.exists(ckpt_fn):
        # use wget to download
        print(f"[Info] Downloading {model_name} checkpoint...")
        os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}")
    bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
    args = bundle["args"]
    phn2num = bundle["phn2num"]
    model = voice_star.VoiceStar(args)
    model.load_state_dict(bundle["model"])
    model.to(device)
    model.eval()

    # If reference_text not provided, use whisper large-v3-turbo
    if reference_text is None:
        print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
        wh_model = whisper.load_model("large-v3-turbo")
        result = wh_model.transcribe(reference_speech)
        prefix_transcript = result["text"]
        print(f"[Info] Whisper transcribed text: {prefix_transcript}")
    else:
        prefix_transcript = reference_text

    # If target_duration not provided, estimate from reference speech + target_text
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.")
    else:
        target_generation_length = float(target_duration)

    # signature from snippet
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        # fallback, just use the 6f79c6a8 checkpoint
        signature = "./pretrained/encodec_6f79c6a8.th"

    if silence_tokens is None:
        # default from snippet
        silence_tokens = []

    if multi_trial is None:
        # default from snippet
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1  # from snippet

    # We can compute prompt_end_frame if we want, from snippet
    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # Prepare tokenizers
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")

    # decode_config from snippet
    decode_config = {
        'top_k': top_k,
        'top_p': top_p,
        'min_p': min_p,
        'temperature': temperature,
        'stop_repetition': stop_repetition,
        'kvcache': kvcache,
        'codec_audio_sr': codec_audio_sr,
        'codec_sr': codec_sr,
        'silence_tokens': silence_tokens,
        'sample_batch_size': sample_batch_size
    }

    # Run inference
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms, pick the first
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()

    # Save the audio (just the generated portion, as the snippet does)
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)

    print(f"[Success] Generated audio saved to {out_path}")


def main():
    fire.Fire(run_inference)

if __name__ == "__main__":
    main()
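Besides the Fire CLI shown in the docstring, `run_inference` can also be called directly from Python; a minimal sketch (the target text here is made up for illustration):

from inference_commandline import run_inference

run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="VoiceStar can generate speech of a requested duration.",
    target_duration=6.0,  # seconds; omit to fall back to the estimate
    top_k=20,             # one of the suggested values above
    seed=1,
)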
inference_gradio.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
gradio_tts_app.py
|
4 |
+
|
5 |
+
Run:
|
6 |
+
python gradio_tts_app.py
|
7 |
+
|
8 |
+
Then open the printed local or public URL in your browser.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import random
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
import whisper
|
17 |
+
import gradio as gr
|
18 |
+
from argparse import Namespace
|
19 |
+
|
20 |
+
# ---------------------------------------------------------------------
|
21 |
+
# The following imports assume your local project structure:
|
22 |
+
# data/tokenizer.py
|
23 |
+
# models/voice_star.py
|
24 |
+
# inference_tts_utils.py
|
25 |
+
# Adjust if needed.
|
26 |
+
# ---------------------------------------------------------------------
|
27 |
+
from data.tokenizer import AudioTokenizer, TextTokenizer
|
28 |
+
from models import voice_star
|
29 |
+
from inference_tts_utils import inference_one_sample
|
30 |
+
|
31 |
+
|
32 |
+
############################################################
|
33 |
+
# Utility Functions
|
34 |
+
############################################################
|
35 |
+
|
36 |
+
def seed_everything(seed=1):
|
37 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
38 |
+
random.seed(seed)
|
39 |
+
np.random.seed(seed)
|
40 |
+
torch.manual_seed(seed)
|
41 |
+
torch.cuda.manual_seed(seed)
|
42 |
+
torch.backends.cudnn.benchmark = False
|
43 |
+
torch.backends.cudnn.deterministic = True
|
44 |
+
|
45 |
+
|
46 |
+
def estimate_duration(ref_audio_path, text):
|
47 |
+
"""
|
48 |
+
Estimate duration based on seconds per character from the reference audio.
|
49 |
+
"""
|
50 |
+
info = torchaudio.info(ref_audio_path)
|
51 |
+
audio_duration = info.num_frames / info.sample_rate
|
52 |
+
length_text = max(len(text), 1)
|
53 |
+
spc = audio_duration / length_text # seconds per character
|
54 |
+
return len(text) * spc
|
55 |
+
|
56 |
+
|
57 |
+
############################################################
|
58 |
+
# Main Inference Function
|
59 |
+
############################################################
|
60 |
+
|
61 |
+
def run_inference(
|
62 |
+
# User-adjustable parameters (no "# do not change" in snippet)
|
63 |
+
reference_speech="./demo/5895_34622_000026_000002.wav",
|
64 |
+
target_text="VoiceStar is a very interesting model, it's duration controllable and can extrapolate",
|
65 |
+
model_name="VoiceStar_840M_40s",
|
66 |
+
model_root="./pretrained",
|
67 |
+
reference_text=None, # optional
|
68 |
+
target_duration=None, # optional
|
69 |
+
top_k=10, # can try 10, 20, 30, 40
|
70 |
+
temperature=1,
|
71 |
+
kvcache=1, # if OOM, set to 0
|
72 |
+
repeat_prompt=1, # use higher to improve speaker similarity
|
73 |
+
stop_repetition=3, # snippet says "will not use it" but not "do not change"
|
74 |
+
seed=1,
|
75 |
+
output_dir="./generated_tts",
|
76 |
+
|
77 |
+
# Non-adjustable parameters (based on snippet instructions)
|
78 |
+
codec_audio_sr=16000, # do not change
|
79 |
+
codec_sr=50, # do not change
|
80 |
+
top_p=1, # do not change
|
81 |
+
min_p=1, # do not change
|
82 |
+
silence_tokens=None, # do not change it
|
83 |
+
multi_trial=None, # do not change it
|
84 |
+
sample_batch_size=1, # do not change
|
85 |
+
cut_off_sec=100, # do not adjust
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
Inference script for VoiceStar TTS.
|
89 |
+
"""
|
90 |
+
# 1. Set seed
|
91 |
+
seed_everything(seed)
|
92 |
+
|
93 |
+
# 2. Load model checkpoint
|
94 |
+
torch.serialization.add_safe_globals([Namespace])
|
95 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
96 |
+
ckpt_fn = os.path.join(model_root, model_name + ".pth")
|
97 |
+
if not os.path.exists(ckpt_fn):
|
98 |
+
# use wget to download
|
99 |
+
print(f"[Info] Downloading {model_name} checkpoint...")
|
100 |
+
os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}")
|
101 |
+
bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
|
102 |
+
args = bundle["args"]
|
103 |
+
phn2num = bundle["phn2num"]
|
104 |
+
|
105 |
+
model = voice_star.VoiceStar(args)
|
106 |
+
model.load_state_dict(bundle["model"])
|
107 |
+
model.to(device)
|
108 |
+
model.eval()
|
109 |
+
|
110 |
+
# 3. If reference_text not provided, transcribe reference speech with Whisper
|
111 |
+
if reference_text is None:
|
112 |
+
print("[Info] No reference_text provided. Transcribing reference_speech with Whisper (large-v3-turbo).")
|
113 |
+
wh_model = whisper.load_model("large-v3-turbo")
|
114 |
+
result = wh_model.transcribe(reference_speech)
|
115 |
+
prefix_transcript = result["text"]
|
116 |
+
print(f"[Info] Whisper transcribed text: {prefix_transcript}")
|
117 |
+
else:
|
118 |
+
prefix_transcript = reference_text
|
119 |
+
|
120 |
+
    # 4. If target_duration not provided, estimate from reference speech + target_text
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f}s. Provide --target_duration if needed.")
    else:
        target_generation_length = float(target_duration)

    # 5. Prepare signature from snippet
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        signature = "./pretrained/encodec_6f79c6a8.th"

    if silence_tokens is None:
        silence_tokens = []

    if multi_trial is None:
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1  # from snippet

    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # 6. Tokenizers
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")

    # 7. decode_config
    decode_config = {
        "top_k": top_k,
        "top_p": top_p,
        "min_p": min_p,
        "temperature": temperature,
        "stop_repetition": stop_repetition,
        "kvcache": kvcache,
        "codec_audio_sr": codec_audio_sr,
        "codec_sr": codec_sr,
        "silence_tokens": silence_tokens,
        "sample_batch_size": sample_batch_size,
    }

    # 8. Run inference
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms, pick the first
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()

    # 9. Save generated audio
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)

    print(f"[Success] Generated audio saved to {out_path}")
    return out_path  # Return the path for Gradio to load


############################
# Transcription function
############################

def transcribe_audio(reference_speech):
    """
    Transcribe uploaded reference audio with Whisper, return text.
    If no file, return empty string.
    """
    if reference_speech is None:
        return ""
    audio_path = reference_speech  # Because type="filepath"

    if not os.path.exists(audio_path):
        return "File not found."

    print("[Info] Transcribing with Whisper...")
    model = whisper.load_model("medium")  # or "large-v2" etc.
    result = model.transcribe(audio_path)
    return result["text"]
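As written, `transcribe_audio` reloads the Whisper checkpoint on every button click. A minimal module-level cache sketch, assuming the same openai-whisper API; `_WHISPER_MODEL` and `transcribe_audio_cached` are hypothetical names, not part of this file:

import os
import whisper  # openai-whisper

# Hypothetical variant: load the checkpoint once at import time so repeated
# clicks don't pay the model-loading cost on each call.
_WHISPER_MODEL = whisper.load_model("medium")

def transcribe_audio_cached(reference_speech):
    # Same contract as transcribe_audio above: empty string when no usable file.
    if reference_speech is None or not os.path.exists(reference_speech):
        return ""
    return _WHISPER_MODEL.transcribe(reference_speech)["text"]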
############################
# Gradio UI
############################

def main():
    with gr.Blocks() as demo:
        gr.Markdown("## VoiceStar TTS with Editable Reference Text")

        with gr.Row():
            reference_speech_input = gr.Audio(
                label="Reference Speech",
                type="filepath",
                elem_id="ref_speech"
            )
            transcribe_button = gr.Button("Transcribe")

        # The transcribed text appears here and can be edited
        reference_text_box = gr.Textbox(
            label="Reference Text (Editable)",
            placeholder="Click 'Transcribe' to auto-fill from reference speech...",
            lines=2
        )

        target_text_box = gr.Textbox(
            label="Target Text",
            value="VoiceStar is a very interesting model, it's duration controllable and can extrapolate to unseen duration.",
            lines=3
        )

        model_name_box = gr.Textbox(
            label="Model Name",
            value="VoiceStar_840M_40s"
        )

        model_root_box = gr.Textbox(
            label="Model Root Directory",
            value="/data1/scratch/pyp/BoostedVoiceEditor/runs"
        )

        reference_duration_box = gr.Textbox(
            label="Target Duration (Optional)",
            placeholder="Leave empty for auto-estimate."
        )

        top_k_box = gr.Number(label="top_k", value=10)
        temperature_box = gr.Number(label="temperature", value=1.0)
        kvcache_box = gr.Number(label="kvcache (1 or 0)", value=1)
        repeat_prompt_box = gr.Number(label="repeat_prompt", value=1)
        stop_repetition_box = gr.Number(label="stop_repetition", value=3)
        seed_box = gr.Number(label="Random Seed", value=1)
        output_dir_box = gr.Textbox(label="Output Directory", value="./generated_tts")

        generate_button = gr.Button("Generate TTS")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")

        # 1) When user clicks "Transcribe", we call `transcribe_audio`
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[reference_speech_input],
            outputs=[reference_text_box],
        )

        # 2) The actual TTS generation function.
        def gradio_inference(
            reference_speech,
            reference_text,
            target_text,
            model_name,
            model_root,
            target_duration,
            top_k,
            temperature,
            kvcache,
            repeat_prompt,
            stop_repetition,
            seed,
            output_dir
        ):
            # Convert any empty strings to None for optional fields
            dur = float(target_duration) if target_duration else None

            out_path = run_inference(
                reference_speech=reference_speech,
                reference_text=reference_text if reference_text else None,
                target_text=target_text,
                model_name=model_name,
                model_root=model_root,
                target_duration=dur,
                top_k=int(top_k),
                temperature=float(temperature),
                kvcache=int(kvcache),
                repeat_prompt=int(repeat_prompt),
                stop_repetition=int(stop_repetition),
                seed=int(seed),
                output_dir=output_dir
            )
            return out_path

        # 3) Link the "Generate TTS" button
        generate_button.click(
            fn=gradio_inference,
            inputs=[
                reference_speech_input,
                reference_text_box,
                target_text_box,
                model_name_box,
                model_root_box,
                reference_duration_box,
                top_k_box,
                temperature_box,
                kvcache_box,
                repeat_prompt_box,
                stop_repetition_box,
                seed_box,
                output_dir_box
            ],
            outputs=[output_audio],
        )

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)

if __name__ == "__main__":
    main()
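For scripted use without the UI, `run_inference` can also be called directly. A minimal sketch, assuming only the keyword arguments visible in `gradio_inference` above; the reference wav ships in `demo/`, while the model root value is a placeholder:

# Hypothetical programmatic call (arguments mirror gradio_inference above).
out_wav = run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    reference_text=None,           # None, as when the textbox is left empty
    target_text="VoiceStar is a very interesting model.",
    model_name="VoiceStar_840M_40s",
    model_root="./pretrained",     # placeholder; point at your checkpoint dir
    target_duration=None,          # None -> estimate_duration() is used
    top_k=10, temperature=1.0, kvcache=1,
    repeat_prompt=1, stop_repetition=3, seed=1,
    output_dir="./generated_tts",
)
print(out_wav)  # path to generated.wav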
inference_tts_utils.py
ADDED
@@ -0,0 +1,155 @@
import argparse, pickle
import logging
import os, random
import numpy as np
import torch
import torchaudio

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
    tokenize_audio,
    tokenize_text
)
import time, tqdm


# this script only works for the musicgen architecture
def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
    parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
    parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
    parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
    parser.add_argument("--top_k", type=int, default=0, help="sampling param")
    parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
    parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
    parser.add_argument("--crop_concat", type=int, default=0)
    parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference; when the number of consecutive repetitions of a token exceeds this, stop it")
    parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
    parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling; NOTE that it does not run inference for several samples, but duplicates one input sample batch_size times, and during inference only the shortest generation is returned")
    parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specify the tokens yourself, rather than using the default")
    return parser.parse_args()


@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, target_generation_length, delay_pattern_increment, prefix_transcript=None, quiet=False, repeat_prompt=0, multi_trial=[]):
    # seq_len_thres = 500 # 10s, 26% of the data in seed tts
    # encode audio
    encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
    # if sequence length is shorter than seq_len_thres, repeat the audio
    # if encoded_frames.shape[2] < seq_len_thres:
    #     encoded_frames = torch.cat([encoded_frames, encoded_frames, encoded_frames], dim=2)
    #     doubled = True
    single_encoded_frames = encoded_frames

    if isinstance(repeat_prompt, int) and repeat_prompt > 0:
        cur_repeat_prompt = repeat_prompt
        while cur_repeat_prompt > 0:
            encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2)
            cur_repeat_prompt -= 1
    elif isinstance(repeat_prompt, str) and repeat_prompt.lower() == "max":
        repeat_prompt = 0
        while encoded_frames.shape[2] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment + single_encoded_frames.shape[2] < model_args.audio_max_length * decode_config['codec_sr']:
            encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2)
            repeat_prompt += 1
    if getattr(model_args, "y_sep_token", None) != None:
        encoded_frames = torch.cat([encoded_frames, torch.LongTensor([model_args.y_sep_token]*model_args.n_codebooks).unsqueeze(0).unsqueeze(2).to(encoded_frames.device)], dim=2)
    # print(encoded_frames.shape)
    original_audio = encoded_frames.transpose(2, 1)  # [1,T,K]
    assert original_audio.ndim == 3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape

    # phonemize
    if isinstance(target_text, list):
        text_tokens = [phn2num[phn] for phn in target_text if phn in phn2num]
    else:
        text_tokens = [phn2num[phn] for phn in
                tokenize_text(
                    text_tokenizer, text=target_text.strip()
                ) if phn in phn2num
            ]
    if getattr(model_args, "x_sep_token", None) != None:
        assert prefix_transcript != None, "prefix_transcript must be provided if x_sep_token is not None"
    if prefix_transcript is not None:
        if isinstance(prefix_transcript, list):
            prefix_tokens = [phn2num[phn] for phn in prefix_transcript if phn in phn2num]
        else:
            prefix_tokens = [phn2num[phn] for phn in
                    tokenize_text(
                        text_tokenizer, text=prefix_transcript.strip()
                    ) if phn in phn2num
                ]
        # if doubled:
        #     prefix_tokens = prefix_tokens + prefix_tokens + prefix_tokens
        single_prefix_tokens = prefix_tokens
        while repeat_prompt > 0:
            prefix_tokens = prefix_tokens + single_prefix_tokens
            repeat_prompt -= 1
        if getattr(model_args, "x_sep_token", None) != None:
            text_tokens = prefix_tokens + [getattr(model_args, "x_sep_token", None)] + text_tokens
        else:
            text_tokens = prefix_tokens + text_tokens
    if getattr(model_args, "add_eos_to_text", 0) != 0:
        text_tokens.append(model_args.add_eos_to_text)
    if getattr(model_args, "add_bos_to_text", 0) != 0:
        text_tokens = [model_args.add_bos_to_text] + text_tokens
    text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
    text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])

    if not quiet:
        logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")

    if getattr(model_args, "parallel_pattern", 0) != 0:
        tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + 2)])  # parallel pattern, therefore only add the empty token (i.e. the sos token) and eos, i.e. 2 more tokens
    else:
        tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment)])  # delay_pattern_increment has accounted for the added eos (sos is counted in the n_codebooks, eos in the +1)

    # forward
    assert decode_config['sample_batch_size'] <= 1
    stime = time.time()
    assert multi_trial == []
    if not quiet:
        logging.info("running inference with batch size 1")
    concat_frames, gen_frames = model.inference_tts(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        original_audio[..., :model_args.n_codebooks].to(device),  # [1,T,8]
        tgt_y_lens=tgt_y_lens.to(device),
        top_k=decode_config['top_k'],
        top_p=decode_config['top_p'],
        min_p=decode_config['min_p'],
        temperature=decode_config['temperature'],
        stop_repetition=decode_config['stop_repetition'],
        kvcache=decode_config['kvcache'],
        silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens']
    )  # output is [1,K,T]
    if not quiet:
        logging.info(f"inference on one sample took: {time.time() - stime:.4f} sec.")

    logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")

    # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
    #     logging.info(f"{timestamp}: {codes.tolist()}")
    # decode (both original and generated)
    # concat_sample = audio_tokenizer.decode(
    #     [(concat_frames, None)] # [1,T,8] -> [1,8,T]
    # )
    if getattr(model_args, "y_sep_token", None) != None:
        concat_frames = torch.cat([concat_frames[:, :, :original_audio.shape[1]-1], concat_frames[:, :, original_audio.shape[1]:]], dim=2)
    concat_sample = audio_tokenizer.decode(
        concat_frames  # [1,8,T]
    )
    gen_sample = audio_tokenizer.decode(
        gen_frames
    )
    # empty cuda cache between runs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # return
    return concat_sample, gen_sample
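The frame budget that `tgt_y_lens` encodes is easy to check by hand. A worked sketch of the delayed-pattern branch above, using the default `codec_sr` of 50 and an assumed 4-codebook model with a 3-second prompt:

# Worked example of the tgt_y_lens arithmetic in inference_one_sample
# (delayed-pattern branch), with illustrative numbers.
codec_sr = 50                               # codec frames per second (default above)
n_codebooks = 4                             # assumed model size
delay_pattern_increment = n_codebooks + 1   # as set in inference_gradio.py

prompt_frames = 150                         # a 3 s reference prompt: 3 * 50
target_generation_length = 5.0              # seconds of new speech requested

tgt_y_lens = int(prompt_frames + codec_sr * target_generation_length
                 + delay_pattern_increment)
assert tgt_y_lens == 150 + 250 + 5          # = 405 codec frames total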
main.py
ADDED
@@ -0,0 +1,82 @@
from pathlib import Path
import torch, os

from tqdm import tqdm
import pickle
import argparse
import logging, datetime
import torch.distributed as dist
from config import MyParser
from steps import trainer
from copy_codebase import copy_codebase

def world_info_from_env():
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, global_rank, world_size

if __name__ == "__main__":
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    torch.cuda.empty_cache()
    args = MyParser().parse_args()
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"exp_dir: {str(exp_dir)}")

    if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)):
        if not os.path.exists("%s/bundle.pth" % args.exp_dir):
            os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth")
        resume = args.resume
        assert(bool(args.exp_dir))
        with open("%s/args.pkl" % args.exp_dir, "rb") as f:
            old_args = pickle.load(f)
        new_args = vars(args)
        old_args = vars(old_args)
        for key in new_args:
            if key not in old_args or old_args[key] != new_args[key]:
                old_args[key] = new_args[key]
        args = argparse.Namespace(**old_args)
        args.resume = resume
    else:
        args.resume = False
        with open("%s/args.pkl" % args.exp_dir, "wb") as f:
            pickle.dump(args, f)

    # make the timeout longer (for generation)
    timeout = datetime.timedelta(seconds=7200)  # 120 minutes

    if args.multinodes:
        _local_rank, _, _ = world_info_from_env()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
    else:
        dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)

    if args.local_wandb:
        os.environ["WANDB_MODE"] = "offline"

    rank = dist.get_rank()
    if rank == 0:
        logging.info(args)
        logging.info(f"exp_dir: {str(exp_dir)}")
    world_size = dist.get_world_size()

    local_rank = int(_local_rank) if args.multinodes else rank
    num_devices = torch.cuda.device_count()
    logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
    for device_idx in range(num_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        logging.info(f"Device {device_idx}: {device_name}")

    torch.cuda.set_device(local_rank)
    if rank == 0:
        user_dir = os.path.expanduser("~")
        codebase_name = "VoiceStar"
        now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"))
    my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
    my_trainer.train()
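Note that the single-node branch above calls `dist.init_process_group(..., init_method='env://')`, so the usual torch.distributed rendezvous variables must be in the environment; `torchrun` exports them automatically. A minimal single-process sketch of the required variables (the port is an arbitrary placeholder):

# Environment expected by init_method='env://' (torchrun sets these for you).
import os
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")   # placeholder port
os.environ.setdefault("RANK", "0")              # read by world_info_from_env
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")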
models/modules/__init__.py
ADDED
File without changes
models/modules/activation.py
ADDED
@@ -0,0 +1,781 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
import warnings  # needed by _canonical_mask below; missing from the original imports
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
import logging
from typing import Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int

def _canonical_mask(
        mask: Optional[Tensor],
        mask_name: str,
        other_type: Optional[DType],
        other_name: str,
        target_type: DType,
        check_other: bool = True,
) -> Optional[Tensor]:

    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            mask = (
                torch.zeros_like(mask, dtype=target_type)
                .masked_fill_(mask, float("-inf"))
            )
    return mask

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
          same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            proj = F.linear(q, w, b)
            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return proj[0], proj[1], proj[2]
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_proj = F.linear(q, w_q, b_q)
            kv_proj = F.linear(k, w_kv, b_kv)
            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
            kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return (q_proj, kv_proj[0], kv_proj[1])
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")

def rotate_half(x):
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary_pos_emb(q, k, q_sinu=None, k_sinu=None, sinu=None, unsqueeze_dim=1, args=None, q_offset=0):
    if sinu is not None:
        q_emb = q * sinu['cos'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(q) * sinu['sin'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim)
        k_emb = k * sinu['cos'][:, :k.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(k) * sinu['sin'][:, :k.shape[2]].unsqueeze(unsqueeze_dim)
    if q_sinu is not None:
        assert sinu is None, "sinu must be None"
        q_emb = q * q_sinu['cos'][:, :, q_offset:q_offset+q.shape[2]] + rotate_half(q) * q_sinu['sin'][:, :, q_offset:q_offset+q.shape[2]]
        k_emb = k * k_sinu['cos'][:, :, :k.shape[2]] + rotate_half(k) * k_sinu['sin'][:, :, :k.shape[2]]
    # else:
    #     assert freqs is not None, "freqs must be provided"
    #     assert key_lens is not None, "key_lens must be provided"
    #     assert query_lens is not None, "query_lens must be provided"
    #     # key_multiple
    #     assert key_lens.ndim==1, key_lens
    #     assert query_lens.ndim==1, query_lens
    #     q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
    #     k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
    #     query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
    #     key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1)
    #     # freqs.shape [1, x_len_max, d]
    #     # torch.set_printoptions(edgeitems=200)
    #     # logging.info(f"{freqs[:, :q.shape[2]]=}")
    #     # logging.info(f"{query_ids_multiple=}")
    #     # logging.info(f"{key_ids_multiple=}")
    #     # print(f"q.shape: {q.shape}")
    #     # print(f"q_offset: {q_offset}")
    #     # print(f"k.shape: {k.shape}")
    #     q_emb = freqs[:, q_offset:q_offset+q.shape[-2]] * query_ids_multiple # [B, q_len_max, d]
    #     k_emb = freqs[:, :k.shape[2]] * key_ids_multiple # [B, k_len_max, d]
    #     # logging.info(f"{q_emb[0, :, :5]=}")
    #     # logging.info(f"{k_emb[0, :, :5]=}")
    #     multiple = k_lens_expanded if multiple_key_length else q_lens_expanded
    #     if progress_no_multiple:
    #         multiple = 1
    #     q_emb = q_emb / q_lens_expanded * multiple * progress_scale
    #     k_emb = k_emb / k_lens_expanded * multiple * progress_scale
    #     q_cos = q_emb.cos().unsqueeze(unsqueeze_dim) # [B, 1, q_len_max, d] # 1 is for nhead
    #     q_sin = q_emb.sin().unsqueeze(unsqueeze_dim)
    #     k_cos = k_emb.cos().unsqueeze(unsqueeze_dim)
    #     k_sin = k_emb.sin().unsqueeze(unsqueeze_dim)
    #     # # visualize rotary pos emb with dummy feature
    #     # q_tmp = torch.ones_like(q)
    #     # k_tmp = torch.ones_like(k)
    #     # q_tmp_emb = q_tmp * q_cos + rotate_half(q_tmp) * q_sin
    #     # k_tmp_emb = k_tmp * k_cos + rotate_half(k_tmp) * k_sin
    #     # sims = q_tmp_emb @ k_tmp_emb.transpose(-2, -1)
    #     # import matplotlib.pyplot as plt
    #     # for i, sim in enumerate(sims):
    #     #     plt.imshow(sim[0][:query_lens[i], :key_lens[i]].detach().cpu().numpy())
    #     #     plt.savefig(f"sim{i}_head0.png")
    #     #     plt.imshow(sim[5][:query_lens[i], :key_lens[i]].detach().cpu().numpy())
    #     #     plt.savefig(f"sim{i}_head5.png")
    #     q_emb = q * q_cos + rotate_half(q) * q_sin
    #     k_emb = k * k_cos + rotate_half(k) * k_sin
    #     # # visualize the real attention weights
    #     # sims = q_emb @ k_emb.transpose(-2, -1)
    #     # from datetime import datetime
    #     # from matplotlib import pyplot as plt
    #     # now = datetime.now()
    #     # for i, sim in enumerate(sims):
    #     #     for ihead, si in enumerate(sim):
    #     #         if query_lens[i] == key_lens[i]:
    #     #             continue
    #     #         plt.imshow(si[:query_lens[i], :key_lens[i]].detach().cpu().numpy())
    #     #         plt.savefig(f"sample{i}_head{ihead}_{now.strftime('%Y-%m-%d_%H-%M-%S')}.png")
    return q_emb, k_emb

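A quick sanity check for `rotate_half`/`apply_rotary_pos_emb`: rotary embeddings rotate each feature pair by a position-dependent angle, so scores between embedded queries and keys depend only on the position offset. A minimal self-contained sketch; the frequency table below is the standard RoPE construction (an assumption), while the `[1, T, d]` cos/sin dict layout is taken from the `sinu` branch above:

# Minimal check that the rotary embedding applied above is relative:
# <q_m, k_n> after embedding depends only on (m - n).
import torch

d, T = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
t = torch.arange(T).float()
freqs = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)      # [T, d], rotate_half pairing
sinu = {"cos": freqs.cos()[None], "sin": freqs.sin()[None]}    # [1, T, d]

x, y = torch.randn(d), torch.randn(d)
q = x.repeat(1, 1, T, 1)      # same content at every position, [bsz, nhead, T, d]
k = y.repeat(1, 1, T, 1)
q_emb, k_emb = apply_rotary_pos_emb(q, k, sinu=sinu)

# same offset (m - n = 2) -> same score, independent of absolute position
s1 = (q_emb[0, 0, 3] * k_emb[0, 0, 1]).sum()
s2 = (q_emb[0, 0, 9] * k_emb[0, 0, 7]).sum()
assert torch.allclose(s1, s2, atol=1e-5)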
189 |
+
|
190 |
+
class MultiheadAttention(Module):
|
191 |
+
r"""Allows the model to jointly attend to information
|
192 |
+
from different representation subspaces as described in the paper:
|
193 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
194 |
+
|
195 |
+
Multi-Head Attention is defined as:
|
196 |
+
|
197 |
+
.. math::
|
198 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
199 |
+
|
200 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
201 |
+
|
202 |
+
``forward()`` will use a special optimized implementation if all of the following
|
203 |
+
conditions are met:
|
204 |
+
|
205 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
206 |
+
restriction will be loosened in the future.)
|
207 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
208 |
+
- training is disabled (using ``.eval()``)
|
209 |
+
- dropout is 0
|
210 |
+
- ``add_bias_kv`` is ``False``
|
211 |
+
- ``add_zero_attn`` is ``False``
|
212 |
+
- ``batch_first`` is ``True`` and the input is batched
|
213 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
214 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
215 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
216 |
+
nor ``attn_mask`` is passed
|
217 |
+
|
218 |
+
If the optimized implementation is in use, a
|
219 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
220 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
221 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
222 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
223 |
+
that is padding can be expected.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
embed_dim: Total dimension of the model.
|
227 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
228 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
229 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
230 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
231 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
232 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
233 |
+
Default: ``False``.
|
234 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
235 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
236 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
237 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
238 |
+
|
239 |
+
Examples::
|
240 |
+
|
241 |
+
>>> # xdoctest: +SKIP
|
242 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
243 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
244 |
+
|
245 |
+
"""
|
246 |
+
__constants__ = ["batch_first"]
|
247 |
+
bias_k: Optional[torch.Tensor]
|
248 |
+
bias_v: Optional[torch.Tensor]
|
249 |
+
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
embed_dim,
|
253 |
+
num_heads,
|
254 |
+
dropout=0.0,
|
255 |
+
bias=True,
|
256 |
+
add_bias_kv=False,
|
257 |
+
add_zero_attn=False,
|
258 |
+
kdim=None,
|
259 |
+
vdim=None,
|
260 |
+
batch_first=False,
|
261 |
+
linear1_cls=Linear,
|
262 |
+
linear2_cls=Linear,
|
263 |
+
device=None,
|
264 |
+
dtype=None,
|
265 |
+
) -> None:
|
266 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
267 |
+
super(MultiheadAttention, self).__init__()
|
268 |
+
self.embed_dim = embed_dim
|
269 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
270 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
271 |
+
self._qkv_same_embed_dim = (
|
272 |
+
self.kdim == embed_dim and self.vdim == embed_dim
|
273 |
+
)
|
274 |
+
|
275 |
+
self.num_heads = num_heads
|
276 |
+
self.dropout = dropout
|
277 |
+
self.batch_first = batch_first
|
278 |
+
self.head_dim = embed_dim // num_heads
|
279 |
+
assert (
|
280 |
+
self.head_dim * num_heads == self.embed_dim
|
281 |
+
), "embed_dim must be divisible by num_heads"
|
282 |
+
|
283 |
+
if add_bias_kv:
|
284 |
+
self.bias_k = Parameter(
|
285 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
286 |
+
)
|
287 |
+
self.bias_v = Parameter(
|
288 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
self.bias_k = self.bias_v = None
|
292 |
+
|
293 |
+
if linear1_cls == Linear:
|
294 |
+
if not self._qkv_same_embed_dim:
|
295 |
+
self.q_proj_weight = Parameter(
|
296 |
+
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
297 |
+
)
|
298 |
+
self.k_proj_weight = Parameter(
|
299 |
+
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
300 |
+
)
|
301 |
+
self.v_proj_weight = Parameter(
|
302 |
+
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
303 |
+
)
|
304 |
+
self.register_parameter("in_proj_weight", None)
|
305 |
+
else:
|
306 |
+
# go down this route with music_gen
|
307 |
+
self.in_proj_weight = Parameter(
|
308 |
+
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
309 |
+
)
|
310 |
+
self.register_parameter("q_proj_weight", None)
|
311 |
+
self.register_parameter("k_proj_weight", None)
|
312 |
+
self.register_parameter("v_proj_weight", None)
|
313 |
+
|
314 |
+
if bias: # True by default
|
315 |
+
self.in_proj_bias = Parameter(
|
316 |
+
torch.empty(3 * embed_dim, **factory_kwargs)
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
self.register_parameter("in_proj_bias", None)
|
320 |
+
self.out_proj = NonDynamicallyQuantizableLinear(
|
321 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
322 |
+
)
|
323 |
+
|
324 |
+
self._reset_parameters()
|
325 |
+
else:
|
326 |
+
if not self._qkv_same_embed_dim:
|
327 |
+
raise NotImplementedError
|
328 |
+
else:
|
329 |
+
self.in_proj_linear = linear1_cls(
|
330 |
+
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
331 |
+
)
|
332 |
+
self.in_proj_weight = self.in_proj_linear.weight
|
333 |
+
|
334 |
+
self.register_parameter("q_proj_weight", None)
|
335 |
+
self.register_parameter("k_proj_weight", None)
|
336 |
+
self.register_parameter("v_proj_weight", None)
|
337 |
+
|
338 |
+
if bias:
|
339 |
+
self.in_proj_bias = self.in_proj_linear.bias
|
340 |
+
else:
|
341 |
+
self.register_parameter("in_proj_bias", None)
|
342 |
+
|
343 |
+
self.out_proj = linear2_cls(
|
344 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
345 |
+
)
|
346 |
+
|
347 |
+
if self.bias_k is not None:
|
348 |
+
xavier_normal_(self.bias_k)
|
349 |
+
if self.bias_v is not None:
|
350 |
+
xavier_normal_(self.bias_v)
|
351 |
+
|
352 |
+
self.add_zero_attn = add_zero_attn
|
353 |
+
|
354 |
+
def _reset_parameters(self):
|
355 |
+
if self._qkv_same_embed_dim:
|
356 |
+
xavier_uniform_(self.in_proj_weight)
|
357 |
+
else:
|
358 |
+
xavier_uniform_(self.q_proj_weight)
|
359 |
+
xavier_uniform_(self.k_proj_weight)
|
360 |
+
xavier_uniform_(self.v_proj_weight)
|
361 |
+
|
362 |
+
if self.in_proj_bias is not None:
|
363 |
+
constant_(self.in_proj_bias, 0.0)
|
364 |
+
constant_(self.out_proj.bias, 0.0)
|
365 |
+
|
366 |
+
if self.bias_k is not None:
|
367 |
+
xavier_normal_(self.bias_k)
|
368 |
+
if self.bias_v is not None:
|
369 |
+
xavier_normal_(self.bias_v)
|
370 |
+
|
371 |
+
def __setstate__(self, state):
|
372 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
373 |
+
if "_qkv_same_embed_dim" not in state:
|
374 |
+
state["_qkv_same_embed_dim"] = True
|
375 |
+
|
376 |
+
super(MultiheadAttention, self).__setstate__(state)
|
377 |
+
|
378 |
+
def forward(
|
379 |
+
self,
|
380 |
+
query: Tensor,
|
381 |
+
key: Tensor,
|
382 |
+
value: Tensor,
|
383 |
+
key_padding_mask: Optional[Tensor] = None,
|
384 |
+
need_weights: bool = True,
|
385 |
+
attn_mask: Optional[Tensor] = None,
|
386 |
+
average_attn_weights: bool = True,
|
387 |
+
past: Optional[Tensor] = None,
|
388 |
+
q_sinu = None,
|
389 |
+
k_sinu = None,
|
390 |
+
sinu = None,
|
391 |
+
args = None,
|
392 |
+
q_offset = 0,
|
393 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
394 |
+
r"""
|
395 |
+
Args:
|
396 |
+
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
397 |
+
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
398 |
+
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
399 |
+
Queries are compared against key-value pairs to produce the output.
|
400 |
+
See "Attention Is All You Need" for more details.
|
401 |
+
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
402 |
+
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
403 |
+
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
404 |
+
See "Attention Is All You Need" for more details.
|
405 |
+
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
406 |
+
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
407 |
+
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
408 |
+
See "Attention Is All You Need" for more details.
|
409 |
+
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
410 |
+
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
411 |
+
Binary and byte masks are supported.
|
412 |
+
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
413 |
+
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
414 |
+
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
415 |
+
Default: ``True``.
|
416 |
+
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
417 |
+
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
418 |
+
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
419 |
+
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
420 |
+
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
421 |
+
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
422 |
+
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
423 |
+
the attention weight.
|
424 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
425 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
426 |
+
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
427 |
+
sinu: for direct original rope positional encoding
|
428 |
+
freqs: for progress monitoring with rope positional encoding
|
429 |
+
q_offset: for progress monitoring with rope positional encoding, during inference when kvcache is on, most of the time query is only of length 1, so we need to offset the query to get the correct progress
|
430 |
+
|
431 |
+
Outputs:
|
432 |
+
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
433 |
+
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
434 |
+
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
435 |
+
embedding dimension ``embed_dim``.
|
436 |
+
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
437 |
+
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
438 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
439 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
440 |
+
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
441 |
+
|
442 |
+
.. note::
|
443 |
+
`batch_first` argument is ignored for unbatched inputs.
|
444 |
+
"""
|
445 |
+
is_batched = query.dim() == 3
|
446 |
+
if key_padding_mask is not None:
|
447 |
+
_kpm_dtype = key_padding_mask.dtype
|
448 |
+
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
449 |
+
key_padding_mask
|
450 |
+
):
|
451 |
+
raise AssertionError(
|
452 |
+
"only bool and floating types of key_padding_mask are supported"
|
453 |
+
)
|
454 |
+
why_not_fast_path = ""
|
455 |
+
if not is_batched:
|
456 |
+
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
457 |
+
elif query is not key or key is not value:
|
458 |
+
# When lifting this restriction, don't forget to either
|
459 |
+
# enforce that the dtypes all match or test cases where
|
460 |
+
# they don't!
|
461 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
462 |
+
elif (
|
463 |
+
self.in_proj_bias is not None
|
464 |
+
and query.dtype != self.in_proj_bias.dtype
|
465 |
+
):
|
466 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
467 |
+
elif (
|
468 |
+
self.in_proj_weight is not None
|
469 |
+
and query.dtype != self.in_proj_weight.dtype
|
470 |
+
):
|
471 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
472 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
473 |
+
elif self.training:
|
474 |
+
why_not_fast_path = "training is enabled"
|
475 |
+
elif not self.batch_first:
|
476 |
+
why_not_fast_path = "batch_first was not True"
|
477 |
+
elif self.bias_k is not None:
|
478 |
+
why_not_fast_path = "self.bias_k was not None"
|
479 |
+
elif self.bias_v is not None:
|
480 |
+
why_not_fast_path = "self.bias_v was not None"
|
481 |
+
elif self.dropout:
|
482 |
+
why_not_fast_path = f"dropout was {self.dropout}, required zero"
|
483 |
+
elif self.add_zero_attn:
|
484 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
485 |
+
elif not self._qkv_same_embed_dim:
|
486 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
487 |
+
elif attn_mask is not None:
|
488 |
+
why_not_fast_path = "attn_mask was not None"
|
489 |
+
elif query.is_nested and key_padding_mask is not None:
|
490 |
+
why_not_fast_path = (
|
491 |
+
"key_padding_mask is not supported with NestedTensor input"
|
492 |
+
)
|
493 |
+
elif self.num_heads % 2 == 1:
|
494 |
+
why_not_fast_path = "num_heads is odd"
|
495 |
+
elif torch.is_autocast_enabled():
|
496 |
+
why_not_fast_path = "autocast is enabled"
|
497 |
+
|
498 |
+
if not why_not_fast_path:
|
499 |
+
tensor_args = (
|
500 |
+
query,
|
501 |
+
key,
|
502 |
+
value,
|
503 |
+
self.in_proj_weight,
|
504 |
+
self.in_proj_bias,
|
505 |
+
self.out_proj.weight,
|
506 |
+
self.out_proj.bias,
|
507 |
+
)
|
508 |
+
# We have to use list comprehensions below because TorchScript does not support
|
509 |
+
# generator expressions.
|
510 |
+
if torch.overrides.has_torch_function(tensor_args):
|
511 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
512 |
+
elif not all(
|
513 |
+
[
|
514 |
+
(x is None or x.is_cuda or "cpu" in str(x.device))
|
515 |
+
for x in tensor_args
|
516 |
+
]
|
517 |
+
):
|
518 |
+
why_not_fast_path = (
|
519 |
+
"some Tensor argument is neither CUDA nor CPU"
|
520 |
+
)
|
521 |
+
elif torch.is_grad_enabled() and any(
|
522 |
+
[x is not None and x.requires_grad for x in tensor_args]
|
523 |
+
):
|
524 |
+
why_not_fast_path = (
|
525 |
+
"grad is enabled and at least one of query or the "
|
526 |
+
"input/output projection weights or biases requires_grad"
|
527 |
+
)
|
528 |
+
if not why_not_fast_path:
|
529 |
+
return torch._native_multi_head_attention(
|
530 |
+
query,
|
531 |
+
key,
|
532 |
+
value,
|
533 |
+
self.embed_dim,
|
534 |
+
self.num_heads,
|
535 |
+
self.in_proj_weight,
|
536 |
+
self.in_proj_bias,
|
537 |
+
self.out_proj.weight,
|
538 |
+
self.out_proj.bias,
|
539 |
+
key_padding_mask
|
540 |
+
if key_padding_mask is not None
|
541 |
+
else attn_mask,
|
542 |
+
need_weights,
|
543 |
+
average_attn_weights,
|
544 |
+
1
|
545 |
+
if key_padding_mask is not None
|
546 |
+
else 0
|
547 |
+
if attn_mask is not None
|
548 |
+
else None,
|
549 |
+
)
|
550 |
+
|
551 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
552 |
+
assert not any_nested, (
|
553 |
+
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
554 |
+
+ f"The fast path was not hit because {why_not_fast_path}"
|
555 |
+
)
|
556 |
+
|
557 |
+
if self.batch_first and is_batched:
|
558 |
+
# make sure that the transpose op does not affect the "is" property
|
559 |
+
if key is value:
|
560 |
+
if query is key:
|
561 |
+
query = key = value = query.transpose(1, 0)
|
562 |
+
else:
|
563 |
+
query, key = [x.transpose(1, 0) for x in (query, key)]
|
564 |
+
value = key
|
565 |
+
else:
|
566 |
+
query, key, value = [
|
567 |
+
x.transpose(1, 0) for x in (query, key, value)
|
568 |
+
]
|
569 |
+
|
570 |
+
if not self._qkv_same_embed_dim:
|
571 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
572 |
+
query,
|
573 |
+
key,
|
574 |
+
value,
|
575 |
+
self.embed_dim,
|
576 |
+
self.num_heads,
|
577 |
+
self.in_proj_weight,
|
578 |
+
self.in_proj_bias,
|
579 |
+
self.bias_k,
|
580 |
+
self.bias_v,
|
581 |
+
self.add_zero_attn,
|
582 |
+
self.dropout,
|
583 |
+
self.out_proj.weight,
|
584 |
+
self.out_proj.bias,
|
585 |
+
training=self.training,
|
586 |
+
key_padding_mask=key_padding_mask,
|
587 |
+
need_weights=need_weights,
|
588 |
+
attn_mask=attn_mask,
|
589 |
+
use_separate_proj_weight=True,
|
590 |
+
q_proj_weight=self.q_proj_weight,
|
591 |
+
k_proj_weight=self.k_proj_weight,
|
592 |
+
v_proj_weight=self.v_proj_weight,
|
593 |
+
average_attn_weights=average_attn_weights,
|
594 |
+
)
|
595 |
+
else:
|
596 |
+
# music_gen should go down this route
|
597 |
+
# logging.info("using in_proj_weight")
|
598 |
+
# attn_output, attn_output_weights = F.multi_head_attention_forward(
|
599 |
+
# query,
|
600 |
+
# key,
|
601 |
+
# value,
|
602 |
+
# self.embed_dim,
|
603 |
+
# self.num_heads,
|
604 |
+
# self.in_proj_weight,
|
605 |
+
# self.in_proj_bias,
|
606 |
+
# self.bias_k,
|
607 |
+
# self.bias_v,
|
608 |
+
# self.add_zero_attn,
|
609 |
+
# self.dropout,
|
610 |
+
# self.out_proj.weight,
|
611 |
+
# self.out_proj.bias,
|
612 |
+
# training=self.training,
|
613 |
+
# key_padding_mask=key_padding_mask,
|
614 |
+
# need_weights=need_weights,
|
615 |
+
# attn_mask=attn_mask,
|
616 |
+
# average_attn_weights=average_attn_weights,
|
617 |
+
# )
|
618 |
+
# re-write the self.attention here, to get k, v cache
|
619 |
+
tgt_len, bsz, embed_dim = query.shape
|
620 |
+
src_len, _, _ = key.shape
|
621 |
+
num_heads = self.num_heads
|
622 |
+
key_padding_mask = _canonical_mask(
|
623 |
+
mask=key_padding_mask,
|
624 |
+
mask_name="key_padding_mask",
|
625 |
+
other_type=_none_or_dtype(attn_mask),
|
626 |
+
other_name="attn_mask",
|
627 |
+
target_type=query.dtype
|
628 |
+
)
|
629 |
+
attn_mask = _canonical_mask(
|
630 |
+
mask=attn_mask,
|
631 |
+
mask_name="attn_mask",
|
632 |
+
other_type=None,
|
633 |
+
other_name="",
|
634 |
+
target_type=query.dtype,
|
635 |
+
check_other=False,
|
636 |
+
)
|
637 |
+
head_dim = self.embed_dim // self.num_heads
|
638 |
+
assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
|
639 |
+
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
640 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
641 |
+
# k_present, v_present = k, v
|
642 |
+
|
643 |
+
#
|
644 |
+
# reshape q, k, v for multihead attention and make em batch first
|
645 |
+
#
|
646 |
+
|
647 |
+
|
648 |
+
|
649 |
+
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
650 |
+
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
651 |
+
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
|
652 |
+
src_len = k.size(1)
|
653 |
+
if past is not None and past.ndim > 2:
|
654 |
+
expected_src_len = src_len + past[0].shape[-2]
|
655 |
+
else:
|
656 |
+
expected_src_len = src_len
|
657 |
+
|
658 |
+
|
659 |
+
# ensure attn_mask's dim is 3
|
660 |
+
if attn_mask is not None:
|
661 |
+
if attn_mask.dim() == 2:
|
662 |
+
correct_2d_size = (tgt_len, expected_src_len)
|
663 |
+
if attn_mask.shape != correct_2d_size:
|
664 |
+
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
|
665 |
+
attn_mask = attn_mask.unsqueeze(0)
|
666 |
+
elif attn_mask.dim() == 3:
|
667 |
+
correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
|
668 |
+
if attn_mask.shape != correct_3d_size:
|
669 |
+
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
|
670 |
+
else:
|
671 |
+
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
672 |
+
|
673 |
+
if key_padding_mask is not None:
|
674 |
+
assert key_padding_mask.shape == (bsz, expected_src_len), \
|
675 |
+
f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
|
676 |
+
key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
|
677 |
+
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
|
678 |
+
if attn_mask is None:
|
679 |
+
attn_mask = key_padding_mask
|
680 |
+
else:
|
681 |
+
attn_mask = attn_mask + key_padding_mask
|
682 |
+
|
683 |
+
if not self.training:
|
684 |
+
dropout_p = 0.0
|
685 |
+
else:
|
686 |
+
dropout_p = self.dropout
|
687 |
+
|
688 |
+
if need_weights:
|
689 |
+
raise NotImplementedError("need_weights not implemented for music_gen")
|
690 |
+
# B, Nt, E = q.shape
|
691 |
+
# q_scaled = q / math.sqrt(E)
|
692 |
+
|
693 |
+
# assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
694 |
+
|
695 |
+
# if attn_mask is not None:
|
696 |
+
        #         attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
        #     else:
        #         attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        #     attn_output_weights = softmax(attn_output_weights, dim=-1)
        #     if dropout_p > 0.0:
        #         attn_output_weights = dropout(attn_output_weights, p=dropout_p)
        #
        #     attn_output = torch.bmm(attn_output_weights, v)
        #
        #     attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        #     attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        #     attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        #
        #     # optionally average attention weights over heads
        #     attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        #     if average_attn_weights:
        #         attn_output_weights = attn_output_weights.mean(dim=1)
        #
        #     if not is_batched:
        #         # squeeze the output if input was unbatched
        #         attn_output = attn_output.squeeze(1)
        #         attn_output_weights = attn_output_weights.squeeze(0)
        #     return attn_output, attn_output_weights
        else:
            # attn_mask can be either (L,S) or (N*num_heads, L, S)
            # if attn_mask's shape is (1, L, S) we need to unsqueeze it to (1, 1, L, S)
            # in order to match the input expected by SDPA, (N, num_heads, L, S)
            if attn_mask is not None:
                if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                    attn_mask = attn_mask.unsqueeze(0)
                else:
                    attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)

            q = q.view(bsz, num_heads, tgt_len, head_dim)
            k = k.view(bsz, num_heads, src_len, head_dim)
            v = v.view(bsz, num_heads, src_len, head_dim)
            # logging.info(f"shape of past: {past.shape}")
            if past is not None:
                present = torch.stack([k, v], dim=0)  # (2, bsz, num_heads, src_len, head_dim)
                if past.ndim > 2:  # past holds a real kv-cache; otherwise it is just a placeholder and the cache is not used
                    pk, pv = past
                    k = torch.cat([pk, k], dim=-2)
                    v = torch.cat([pv, v], dim=-2)
            else:
                present = None
            # when using the kv-cache, the position of q needs to be offset when applying the rotary pos emb;
            # here we assume the kv-cache is only used in self-attention, so k and q always have the same seq_len
            # rope positional encoding
            if sinu is not None:
                # direct rotary
                # logging.info("perform rotary positional encoding")
                q, k = apply_rotary_pos_emb(q, k, sinu=sinu, args=args, q_offset=q_offset)
            if q_sinu is not None:
                assert sinu is None, "sinu and q_sinu cannot be used together"
                assert k_sinu is not None, "k_sinu must be provided"
                q, k = apply_rotary_pos_emb(q, k, q_sinu=q_sinu, k_sinu=k_sinu, args=args, q_offset=q_offset)

            # if self.training and this is cross-attention, collect attention weights for the alignment loss
            if args is not None and self.training and getattr(args, "attention_alignment_loss", 0) and not (query is key):
                attention_weights = q @ k.transpose(-1, -2)
            else:
                attention_weights = None
            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
            attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

            attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
            attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
            if not is_batched:
                # squeeze the output if the input was unbatched
                attn_output = attn_output.squeeze(1)
            # if self.training:
            #     return attn_output, None
            # else:
            #     return (attn_output, present), None

        # hard-coded: the code does not support returning attention weights yet
        attn_output_weights = None
        if self.batch_first and is_batched:
            if attention_weights is not None:
                return {"attn_output": attn_output.transpose(1, 0), "attention_weights": attention_weights}, present
            return attn_output.transpose(1, 0), present
        else:
            if attention_weights is not None:
                return {"attn_output": attn_output, "attention_weights": attention_weights}, present
            return attn_output, present
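A minimal standalone sketch (editorial addition, not part of the uploaded files) of what the kv-cache branch above does: `present` keeps only the new keys/values, while attention runs over the cached steps plus the new step. All sizes and tensor names below are illustrative.

# Standalone sketch of the kv-cache path; all names are illustrative.
import torch
import torch.nn.functional as F

bsz, num_heads, head_dim = 2, 4, 16
past_len, new_len = 7, 1

# cached keys/values from previous decoding steps, shaped like `past` above
pk = torch.randn(bsz, num_heads, past_len, head_dim)
pv = torch.randn(bsz, num_heads, past_len, head_dim)

# projections for the current step only
q = torch.randn(bsz, num_heads, new_len, head_dim)
k = torch.randn(bsz, num_heads, new_len, head_dim)
v = torch.randn(bsz, num_heads, new_len, head_dim)

# `present` stores only the new keys/values; the caller accumulates the cache
present = torch.stack([k, v], dim=0)           # (2, bsz, num_heads, new_len, head_dim)
k = torch.cat([pk, k], dim=-2)                 # (bsz, num_heads, past_len+new_len, head_dim)
v = torch.cat([pv, v], dim=-2)

# with a cache, q attends over all past_len+new_len positions
out = F.scaled_dot_product_attention(q, k, v)  # (bsz, num_heads, new_len, head_dim)
assert out.shape == (bsz, num_heads, new_len, head_dim)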
models/modules/embedding.py
ADDED
@@ -0,0 +1,158 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import logging
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        X = self.word_embeddings(x)
        X = self.dropout(X)

        return X


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


class SinePositionalEmbedding_progress(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
        args=None,
    ):
        super().__init__()
        self.args = args
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(args.sinusoidal_base) / self.dim_model)
        ).unsqueeze(0).unsqueeze(0)  # [1, 1, dim_model//2]
        self.position = None
        self.extend_position(torch.tensor(0.0).expand(1, 10000))
        self.progress_scale = getattr(args, "progress_scale", 1.0)

    def extend_position(self, x):
        """Reset the positional encodings."""
        if self.position is not None:
            if self.div_term.dtype != x.dtype or self.div_term.device != x.device:
                self.div_term = self.div_term.to(dtype=x.dtype, device=x.device)
            if self.position.size(1) >= x.size(1):
                if self.position.dtype != x.dtype or self.position.device != x.device:
                    self.position = self.position.to(dtype=x.dtype, device=x.device)
                return
        if self.reverse:
            self.position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(0).unsqueeze(2).to(x)
        else:
            self.position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(0).unsqueeze(2).to(x)  # [1, seq_len, 1]

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 3, x.shape
        self.extend_position(x)
        x_lens = x_lens.unsqueeze(1).unsqueeze(2)  # [B, 1, 1]
        multiple = x_lens / (x_lens - 1)
        progress = self.position[:, :x.shape[1]] * multiple / x_lens * self.progress_scale
        # torch.set_printoptions(edgeitems=100)
        # for i in range(x_lens.shape[0]):
        #     logging.info(f"{progress[i, :x_lens[i,0,0], 0]}")
        invfreq = self.div_term * progress  # might want to use a scale term here
        pe = torch.zeros_like(x)
        pe[..., 0::2] = torch.sin(invfreq)
        pe[..., 1::2] = torch.cos(invfreq)
        output = x * self.x_scale + self.alpha * pe
        return self.dropout(output)
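A quick standalone check (editorial addition, not part of the uploaded files) of the progress formula in SinePositionalEmbedding_progress.forward: position * (L/(L-1)) / L reduces to position/(L-1), so progress runs from 0.0 at the first frame to 1.0 at the last frame regardless of utterance length, before progress_scale is applied.

# Standalone check of the progress formula; L plays the role of x_lens.
import torch

for L in (5, 10):
    position = torch.arange(L, dtype=torch.float32)
    multiple = L / (L - 1)
    progress = position * multiple / L  # == position / (L - 1)
    print(L, progress.tolist())         # first entry 0.0, last entry 1.0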
models/modules/sampling.py
ADDED
@@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits

def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    # temperature: (`optional`) float
    #     The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
    # top_k: (`optional`) int
    #     The number of highest-probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 10 here.
    # top_p: (`optional`) float
    #     The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.

    # Temperature (higher temperature => more likely to sample low-probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
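Usage sketch (editorial addition, not part of the uploaded files): sampling one token per batch item with the functions above, assuming they are importable as models.modules.sampling. Note that top_k_top_p_filtering writes filter_value into the logits tensor it is given, in place, so pass a copy if the original logits are still needed.

# Standalone usage sketch; the import path assumes the repo layout shown in this diff.
import torch
from models.modules.sampling import topk_sampling

logits = torch.randn(3, 1024)  # (batch, vocab)
token = topk_sampling(logits.clone(), top_k=10, top_p=0.9, temperature=0.8)
print(token.shape)  # torch.Size([3, 1]); one sampled token id per batch item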
models/modules/scaling.py
ADDED
@@ -0,0 +1,1406 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import logging
import random
import math
from functools import reduce
from itertools import repeat
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding

# from valle.utils import Transpose

class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)

class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


def _compute_scale_factor(
    x: Tensor,
    channel_dim: int,
    min_abs: float,
    max_abs: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs; it can be at most
        # max_factor, when x_abs_mean << min_abs.
        below_threshold = (
            (min_abs - x_abs_mean) * (gain_factor / min_abs)
        ).clamp(min=0, max=max_factor)

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )

    return below_threshold - above_threshold


def _compute_sign_factor(
    x: Tensor,
    channel_dim: int,
    min_positive: float,
    max_positive: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if proportion_positive <= max_positive, else can be
        # as large as -max_factor.
        factor2 = (
            (proportion_positive - max_positive)
            * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor


class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


class RandomClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(
        ctx, ans_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None


def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    return RandomClampFunction.apply(x, min, max, prob, reflect)


def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.
    """
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    # for elements where is_too_small is true, random_val will contain +-min_abs with
    # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
    # for those elements].
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in the forward pass; in the backward pass, gets rid of very small grads using
    a randomized approach that preserves expectations (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            return (
                random_cast_to_half(
                    ans_grad.to(torch.float32), min_abs=ctx.min_abs
                ),
                None,
            )
        else:
            return ans_grad, None


class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method, intended to increase
    accuracy of training when using amp (automatic mixed precision)
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        if (
            torch.jit.is_scripting()
            or not self.training
            or torch.jit.is_tracing()
        ):
            return x
        else:
            return RandomGradFunction.apply(x, self.min_abs)


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def softmax(x: Tensor, dim: int):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)

    return SoftmaxFunction.apply(x, dim)
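A standalone sanity check (editorial addition, not part of the uploaded files) that the custom SoftmaxFunction above matches torch.softmax in both value and gradient in float32; its backward implements d/dx softmax(x) = y*g - y*sum(y*g), which is the standard softmax Jacobian-vector product.

# Standalone check, assuming the SoftmaxFunction class above is in scope.
import torch

x1 = torch.randn(4, 6, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

y1 = SoftmaxFunction.apply(x1, -1)  # custom autograd function defined above
y2 = torch.softmax(x2, dim=-1)      # reference implementation

g = torch.randn_like(y1)
y1.backward(g)
y2.backward(g)
assert torch.allclose(y1, y2, atol=1e-6)
assert torch.allclose(x1.grad, x2.grad, atol=1e-6)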
class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x ** 2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual ** 2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None


class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm. The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features. Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick. We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
            scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
       eps_min: float
       eps_max: float
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # with probability 0.25, in training mode, clamp eps between the min
            # and max; this will encourage it to learn parameters within the
            # allowed range by making parameters that are outside the allowed
            # range noisy.

            # gradients to allow the parameter to get back into the allowed
            # region if it happens to exit it.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales
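For intuition (editorial addition, not part of the uploaded files): BasicNorm is RMS-style normalization with a learned floor and, unlike LayerNorm, has no per-channel weight or bias. A small check of the formula in forward(), assuming the class above is in scope:

# Standalone check of BasicNorm's scaling rule.
import torch

norm = BasicNorm(num_channels=8)  # class defined above
x = 10.0 * torch.randn(2, 3, 8)   # (batch, time, channels)
y = norm(x)

# forward() computes x * (mean(x**2) + exp(eps))**-0.5 along the channel dim
expected = x * (x.pow(2).mean(dim=-1, keepdim=True) + norm.eps.exp()) ** -0.5
assert torch.allclose(y, expected)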
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts,
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(
                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
            )
    return ans


def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Conv1d accepts,
        e.g. in_channels, out_channels, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(
                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
            )
    return ans


def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d
    """
    return nn.Sequential(
        Transpose(),
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
    )


def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose
    """
    return nn.Sequential(
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
        Transpose(),
    )


def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d
    """
    return nn.Sequential(
        Transpose(),
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose
    """
    return nn.Sequential(
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )


class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
        self.register_buffer("spectral_norm", sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        with torch.no_grad():
            u = self.u
            v = self.weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = self.weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            self.u.data.copy_(u)
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            self.spectral_norm.data.copy_(sigma)
        weight = (self.sigma / sigma) * self.weight
        return weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)


class SRConv1d(SRLinear):
    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        in_features = in_features * kernel_size
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        in_features = self.in_features // self.kernel_size
        weight = self.get_weight().view(
            self.out_features, in_features, self.kernel_size
        )
        return nn.functional.conv1d(
            x, weight, bias=self.bias, stride=self.stride, padding=self.padding
        )


def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> SRConv1d
    """
    return nn.Sequential(
        Transpose(),
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    SRConv1d -> Transpose
    """
    return nn.Sequential(
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )
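A small sketch (editorial addition, not part of the uploaded files) of what initial_scale does in ScaledLinear: the underlying nn.Linear is constructed normally and its weights are then scaled in place, so with the same seed the weights equal initial_scale times a reference layer's. The bias is separately re-initialized to a matching small uniform range, so only the weights are compared here.

# Standalone check of ScaledLinear's initial_scale, assuming the function above is in scope.
import torch

torch.manual_seed(0)
small = ScaledLinear(16, 16, initial_scale=0.25)
torch.manual_seed(0)
ref = torch.nn.Linear(16, 16)

assert torch.allclose(small.weight, 0.25 * ref.weight)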
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time. It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
          -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
          that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
          that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs: the minimum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
       min_prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
          on each forward(). This is done randomly to prevent all layers
          from doing it at the same time. Early in training we may use
          higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        if (
            torch.jit.is_scripting()
            or not x.requires_grad
            or torch.jit.is_tracing()
        ):
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'. don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            sign_gain_factor = 0.5
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit". E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training. For this reason, where we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that are over the limit. (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away). The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss)
    # you must use x for something, or this will be ineffective.
    return x


def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
    if x.ndim == 2:
        return x.diag()
    else:
        (batch, dim, dim) = x.shape
        x = x.reshape(batch, dim * dim)
        x = x[:, :: dim + 1]
        assert x.shape == (batch, dim)
        return x
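A standalone sketch (editorial addition, not part of the uploaded files) of penalize_abs_values_gt's effect, assuming the full scaling.py module above/below is imported: the forward pass is the identity, while the backward pass adds +/-penalty to the gradient wherever |x| exceeds the limit.

# Standalone check of penalize_abs_values_gt's gradient behavior.
import torch

x = torch.tensor([0.5, -12.0, 20.0], requires_grad=True)
y = penalize_abs_values_gt(x, limit=10.0, penalty=0.1)  # function defined above
assert torch.equal(y, x)  # forward is the identity

y.sum().backward()
# grad is 1 everywhere from sum(), plus penalty * sign(x) where |x| > limit
print(x.grad)  # tensor([1.0000, 0.9000, 1.1000])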
def _whitening_metric(x: Tensor, num_groups: int):
|
809 |
+
"""
|
810 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
811 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
812 |
+
and also between groups.
|
813 |
+
Args:
|
814 |
+
x: a Tensor of shape (*, num_channels)
|
815 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
816 |
+
Returns:
|
817 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
818 |
+
greater than 1.0 otherwise.
|
819 |
+
"""
|
820 |
+
assert x.dtype != torch.float16
|
821 |
+
x = x.reshape(-1, x.shape[-1])
|
822 |
+
(num_frames, num_channels) = x.shape
|
823 |
+
assert num_channels % num_groups == 0
|
824 |
+
channels_per_group = num_channels // num_groups
|
825 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
826 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
827 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
828 |
+
# My experience has been that when we "mess with the gradients" like this,
|
829 |
+
# it's better not do anything that tries to move the mean around, because
|
830 |
+
# that can easily cause instability.
|
831 |
+
x = x - x.mean(dim=1, keepdim=True)
|
832 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
833 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
834 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
835 |
+
# the following expression is what we'd get if we took the matrix product
|
836 |
+
# of each covariance and measured the mean of its trace, i.e.
|
837 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
838 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
839 |
+
num_groups * channels_per_group
|
840 |
+
)
|
841 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
842 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
843 |
+
return metric
|
844 |
+
|
845 |
+
|
846 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
847 |
+
@staticmethod
|
848 |
+
def forward(
|
849 |
+
ctx,
|
850 |
+
x: Tensor,
|
851 |
+
num_groups: int,
|
852 |
+
whitening_limit: float,
|
853 |
+
grad_scale: float,
|
854 |
+
) -> Tensor:
|
855 |
+
ctx.save_for_backward(x)
|
856 |
+
ctx.num_groups = num_groups
|
857 |
+
ctx.whitening_limit = whitening_limit
|
858 |
+
ctx.grad_scale = grad_scale
|
859 |
+
return x
|
860 |
+
|
861 |
+
@staticmethod
|
862 |
+
def backward(ctx, x_grad: Tensor):
|
863 |
+
(x_orig,) = ctx.saved_tensors
|
864 |
+
with torch.enable_grad():
|
865 |
+
with torch.cuda.amp.autocast(enabled=False):
|
866 |
+
x_detached = x_orig.to(torch.float32).detach()
|
867 |
+
x_detached.requires_grad = True
|
868 |
+
|
869 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
870 |
+
|
871 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
872 |
+
logging.info(
|
873 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
874 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
875 |
+
)
|
876 |
+
|
877 |
+
(metric - ctx.whitening_limit).relu().backward()
|
878 |
+
penalty_grad = x_detached.grad
|
879 |
+
scale = ctx.grad_scale * (
|
880 |
+
x_grad.to(torch.float32).norm()
|
881 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
882 |
+
)
|
883 |
+
penalty_grad = penalty_grad * scale
|
884 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
885 |
+
|
886 |
+
|
887 |
+
class Whiten(nn.Module):
|
888 |
+
def __init__(
|
889 |
+
self,
|
890 |
+
num_groups: int,
|
891 |
+
whitening_limit: float,
|
892 |
+
prob: Union[float, Tuple[float, float]],
|
893 |
+
grad_scale: float,
|
894 |
+
):
|
895 |
+
"""
|
896 |
+
Args:
|
897 |
+
num_groups: the number of groups to divide the channel dim into before
|
898 |
+
whitening. We will attempt to make the feature covariance
|
899 |
+
within each group, after mean subtraction, as "white" as possible,
|
900 |
+
while having the same trace across all groups.
|
901 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
902 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
903 |
+
white, with exactly the same trace across groups; larger values
|
904 |
+
give more freedom. E.g. 2.0.
|
905 |
+
prob: the probability with which we apply the gradient modification
|
906 |
+
(also affects the grad scale). May be supplied as a float,
|
907 |
+
or as a pair (min_prob, max_prob)
|
908 |
+
|
909 |
+
grad_scale: determines the scale on the gradient term from this object,
|
910 |
+
relative to the rest of the gradient on the attention weights.
|
911 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
912 |
+
"""
|
913 |
+
super(Whiten, self).__init__()
|
914 |
+
assert num_groups >= 1
|
915 |
+
assert whitening_limit >= 1
|
916 |
+
assert grad_scale >= 0
|
917 |
+
self.num_groups = num_groups
|
918 |
+
self.whitening_limit = whitening_limit
|
919 |
+
if isinstance(prob, float):
|
920 |
+
assert 0 < prob <= 1
|
921 |
+
self.prob = prob
|
922 |
+
else:
|
923 |
+
(self.min_prob, self.max_prob) = prob
|
924 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
925 |
+
self.prob = self.max_prob
|
926 |
+
|
927 |
+
self.grad_scale = grad_scale
|
928 |
+
|
929 |
+
def forward(self, x: Tensor) -> Tensor:
|
930 |
+
"""
|
931 |
+
In the forward pass, this function just returns the input unmodified.
|
932 |
+
In the backward pass, it will modify the gradients to ensure that the
|
933 |
+
distribution in each group has close to (lambda times I) as the covariance
|
934 |
+
after mean subtraction, with the same lambda across groups.
|
935 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
936 |
+
constraint.
|
937 |
+
|
938 |
+
Args:
|
939 |
+
x: the input of shape (*, num_channels)
|
940 |
+
|
941 |
+
Returns:
|
942 |
+
x, unmodified. You should make sure
|
943 |
+
you use the returned value, or the graph will be freed
|
944 |
+
and nothing will happen in backprop.
|
945 |
+
"""
|
946 |
+
if (
|
947 |
+
not x.requires_grad
|
948 |
+
or random.random() > self.prob
|
949 |
+
or self.grad_scale == 0
|
950 |
+
):
|
951 |
+
return _no_op(x)
|
952 |
+
else:
|
953 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
954 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
955 |
+
# we are above or below the threshold.
|
956 |
+
if (
|
957 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
958 |
+
> self.whitening_limit
|
959 |
+
):
|
960 |
+
# there would be a change to the grad.
|
961 |
+
self.prob = self.max_prob
|
962 |
+
else:
|
963 |
+
self.prob = self.min_prob
|
964 |
+
|
965 |
+
return WhiteningPenaltyFunction.apply(
|
966 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
967 |
+
)
|
968 |
+
|
969 |
+
|
970 |
+
class WithLoss(torch.autograd.Function):
|
971 |
+
@staticmethod
|
972 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
973 |
+
ctx.y_shape = y.shape
|
974 |
+
return x
|
975 |
+
|
976 |
+
@staticmethod
|
977 |
+
def backward(ctx, ans_grad: Tensor):
|
978 |
+
return ans_grad, torch.ones(
|
979 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
980 |
+
)
|
981 |
+
|
982 |
+
|
983 |
+
def with_loss(x, y):
|
984 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
985 |
+
return x
|
986 |
+
# returns x but adds y.sum() to the loss function.
|
987 |
+
return WithLoss.apply(x, y)
|
988 |
+
|
989 |
+
|
990 |
+
def _no_op(x: Tensor) -> Tensor:
|
991 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
992 |
+
return x
|
993 |
+
else:
|
994 |
+
# a no-op function that will have a node in the autograd graph,
|
995 |
+
# to avoid certain bugs relating to backward hooks
|
996 |
+
return x.chunk(1, dim=-1)[0]
|
997 |
+
|
998 |
+
|
999 |
+
class Identity(torch.nn.Module):
|
1000 |
+
def __init__(self):
|
1001 |
+
super(Identity, self).__init__()
|
1002 |
+
|
1003 |
+
def forward(self, x):
|
1004 |
+
return _no_op(x)
|
1005 |
+
|
1006 |
+
|
1007 |
+
class MaxEig(torch.nn.Module):
|
1008 |
+
"""
|
1009 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
1010 |
+
that any given direction in activation space accounts for more than
|
1011 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1012 |
+
|
1013 |
+
|
1014 |
+
Args:
|
1015 |
+
num_channels: the number of channels
|
1016 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1017 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1018 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1019 |
+
features/channels, after mean subtraction, that can come from
|
1020 |
+
any given eigenvalue.
|
1021 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1022 |
+
of forward(), assuming last time we applied the constraint it was
|
1023 |
+
not active; supplied for speed.
|
1024 |
+
scale: determines the scale with which we modify the gradients, relative
|
1025 |
+
to the existing / unmodified gradients
|
1026 |
+
"""
|
1027 |
+
|
1028 |
+
def __init__(
|
1029 |
+
self,
|
1030 |
+
num_channels: int,
|
1031 |
+
channel_dim: int,
|
1032 |
+
max_var_per_eig: float = 0.2,
|
1033 |
+
min_prob: float = 0.01,
|
1034 |
+
scale: float = 0.01,
|
1035 |
+
):
|
1036 |
+
super(MaxEig, self).__init__()
|
1037 |
+
self.num_channels = num_channels
|
1038 |
+
self.channel_dim = channel_dim
|
1039 |
+
self.scale = scale
|
1040 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1041 |
+
self.max_var_per_eig = max_var_per_eig
|
1042 |
+
|
1043 |
+
# we figure out the dominant direction using the power method: starting with
|
1044 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1045 |
+
with torch.no_grad():
|
1046 |
+
# arbitrary choice; we would use randn() but want to leave the rest of the model's
|
1047 |
+
# random parameters unchanged for comparison
|
1048 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1049 |
+
direction = direction / direction.norm()
|
1050 |
+
self.register_buffer("max_eig_direction", direction)
|
1051 |
+
|
1052 |
+
self.min_prob = min_prob
|
1053 |
+
# cur_prob is the current probability with which we'll apply the MaxEig gradient change.
|
1054 |
+
# We'll regress this towards min_prob each time we try to apply it and it is not
|
1055 |
+
# active.
|
1056 |
+
self.cur_prob = 1.0
|
1057 |
+
|
1058 |
+
def forward(self, x: Tensor) -> Tensor:
|
1059 |
+
if (
|
1060 |
+
torch.jit.is_scripting()
|
1061 |
+
or self.max_var_per_eig <= 0
|
1062 |
+
or random.random() > self.cur_prob
|
1063 |
+
or torch.jit.is_tracing()
|
1064 |
+
):
|
1065 |
+
return _no_op(x)
|
1066 |
+
|
1067 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1068 |
+
eps = 1.0e-20
|
1069 |
+
orig_x = x
|
1070 |
+
x = x.to(torch.float32)
|
1071 |
+
with torch.no_grad():
|
1072 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1073 |
+
-1, self.num_channels
|
1074 |
+
)
|
1075 |
+
x = x - x.mean(dim=0)
|
1076 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1077 |
+
x, self.max_eig_direction
|
1078 |
+
)
|
1079 |
+
x_var = (x ** 2).mean()
|
1080 |
+
x_residual = x - coeffs * new_direction
|
1081 |
+
x_residual_var = (x_residual ** 2).mean()
|
1082 |
+
|
1083 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1084 |
+
# by the top eigen-direction.
|
1085 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1086 |
+
x_var + 1.0e-20
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1090 |
+
self._set_direction(
|
1091 |
+
0.1 * self.max_eig_direction + new_direction
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1095 |
+
logging.info(
|
1096 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
if variance_proportion >= self.max_var_per_eig:
|
1100 |
+
# The constraint is active. Note, we should quite rarely
|
1101 |
+
# reach here, only near the beginning of training if we are
|
1102 |
+
# starting to diverge, should this constraint be active.
|
1103 |
+
cur_prob = self.cur_prob
|
1104 |
+
self.cur_prob = (
|
1105 |
+
1.0 # next time, do the update with probability 1.0.
|
1106 |
+
)
|
1107 |
+
return MaxEigLimiterFunction.apply(
|
1108 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1109 |
+
)
|
1110 |
+
else:
|
1111 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1112 |
+
# long as the constraint is inactive.
|
1113 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1114 |
+
return orig_x
|
1115 |
+
|
1116 |
+
def _set_direction(self, direction: Tensor):
|
1117 |
+
"""
|
1118 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1119 |
+
"""
|
1120 |
+
direction = direction.detach()
|
1121 |
+
direction = direction / direction.norm()
|
1122 |
+
direction_sum = direction.sum().item()
|
1123 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1124 |
+
self.max_eig_direction[:] = direction
|
1125 |
+
else:
|
1126 |
+
logging.info(
|
1127 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1128 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
def _find_direction_coeffs(
|
1132 |
+
self, x: Tensor, prev_direction: Tensor
|
1133 |
+
) -> Tuple[Tensor, Tensor]:
|
1134 |
+
"""
|
1135 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1136 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1137 |
+
Args:
|
1138 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1139 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1140 |
+
of the top eigen-direction, or a random direction if this is the first
|
1141 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1142 |
+
|
1143 |
+
Returns: (cur_direction, coeffs), where:
|
1144 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1145 |
+
estimate of the top eigen-direction.
|
1146 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1147 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1148 |
+
"""
|
1149 |
+
(num_frames, num_channels) = x.shape
|
1150 |
+
assert num_channels > 1 and num_frames > 1
|
1151 |
+
assert prev_direction.shape == (num_channels,)
|
1152 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1153 |
+
# actually represent the coeffs up to a constant positive factor.
|
1154 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1155 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1156 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1157 |
+
)
|
1158 |
+
return cur_direction, coeffs
|
1159 |
+
|
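To watch the power iteration converge, one can call _find_direction_coeffs repeatedly on synthetic data with a planted dominant direction (editor-added sketch; all names below are local to the example):

# editor's example
torch.manual_seed(0)
planted = torch.randn(64)
x = torch.randn(500, 64) + 5.0 * torch.randn(500, 1) * planted
x = x - x.mean(dim=0)
m = MaxEig(64, channel_dim=-1)
d = m.max_eig_direction
for _ in range(5):
    d, coeffs = m._find_direction_coeffs(x, d)
# cosine similarity with the planted direction should be near +/- 1
print(((d * planted).sum() / (d.norm() * planted.norm())).item())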
1160 |
+
|
1161 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1162 |
+
"""
|
1163 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1164 |
+
This is a definition, originally motivated by its close numerical
|
1165 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1166 |
+
|
1167 |
+
Memory-efficient derivative computation:
|
1168 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1169 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1170 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1171 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1172 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1173 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1174 |
+
... so we just need to remember s(x) but not x itself.
|
1175 |
+
"""
|
1176 |
+
|
1177 |
+
@staticmethod
|
1178 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1179 |
+
requires_grad = x.requires_grad
|
1180 |
+
x_dtype = x.dtype
|
1181 |
+
if x.dtype == torch.float16:
|
1182 |
+
x = x.to(torch.float32)
|
1183 |
+
|
1184 |
+
s = torch.sigmoid(x - 1.0)
|
1185 |
+
y = x * s
|
1186 |
+
|
1187 |
+
if requires_grad:
|
1188 |
+
deriv = y * (1 - s) + s
|
1189 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1190 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1191 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
|
1192 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1193 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1194 |
+
# floors), should be expectation-preserving.
|
1195 |
+
floor = -0.043637
|
1196 |
+
ceil = 1.2
|
1197 |
+
d_scaled = (deriv - floor) * (
|
1198 |
+
255.0 / (ceil - floor)
|
1199 |
+
) + torch.rand_like(deriv)
|
1200 |
+
if __name__ == "__main__":
|
1201 |
+
# for self-testing only.
|
1202 |
+
assert d_scaled.min() >= 0.0
|
1203 |
+
assert d_scaled.max() < 256.0
|
1204 |
+
d_int = d_scaled.to(torch.uint8)
|
1205 |
+
ctx.save_for_backward(d_int)
|
1206 |
+
if x_dtype == torch.float16 or torch.is_autocast_enabled():
|
1207 |
+
y = y.to(torch.float16)
|
1208 |
+
return y
|
1209 |
+
|
1210 |
+
@staticmethod
|
1211 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1212 |
+
(d,) = ctx.saved_tensors
|
1213 |
+
# the same constants as used in forward pass.
|
1214 |
+
floor = -0.043637
|
1215 |
+
ceil = 1.2
|
1216 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1217 |
+
return y_grad * d
|
1218 |
+
|
1219 |
+
|
1220 |
+
class DoubleSwish(torch.nn.Module):
|
1221 |
+
def forward(self, x: Tensor) -> Tensor:
|
1222 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1223 |
+
that is computed here as x * sigmoid(x - 1).
|
1224 |
+
"""
|
1225 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1226 |
+
return x * torch.sigmoid(x - 1.0)
|
1227 |
+
return DoubleSwishFunction.apply(x)
|
1228 |
+
|
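A quick numeric check (editor-added, not part of the repo) of the docstring's claim that x * sigmoid(x - 1) stays close to swish(swish(x)):

# editor's example
x = torch.linspace(-6.0, 6.0, 1001)
swish = lambda t: t * torch.sigmoid(t)
gap = (DoubleSwish()(x) - swish(swish(x))).abs().max()
print(gap.item())  # stays well under 0.1 on this range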
1229 |
+
|
1230 |
+
def BalancedDoubleSwish(
|
1231 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1232 |
+
) -> nn.Sequential:
|
1233 |
+
"""
|
1234 |
+
ActivationBalancer -> DoubleSwish
|
1235 |
+
"""
|
1236 |
+
balancer = ActivationBalancer(
|
1237 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1238 |
+
)
|
1239 |
+
return nn.Sequential(
|
1240 |
+
balancer,
|
1241 |
+
DoubleSwish(),
|
1242 |
+
)
|
1243 |
+
|
1244 |
+
|
1245 |
+
def _test_max_eig():
|
1246 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1247 |
+
logging.info(f"proportion = {proportion}")
|
1248 |
+
x = torch.randn(100, 128)
|
1249 |
+
direction = torch.randn(128)
|
1250 |
+
coeffs = torch.randn(100, 1)
|
1251 |
+
x += proportion * direction * coeffs
|
1252 |
+
|
1253 |
+
x.requires_grad = True
|
1254 |
+
|
1255 |
+
num_channels = 128
|
1256 |
+
m = MaxEig(
|
1257 |
+
num_channels, 1, 0.5, scale=0.1  # channel_dim=1, max_var_per_eig=0.5
|
1258 |
+
)
|
1259 |
+
|
1260 |
+
for _ in range(4):
|
1261 |
+
y = m(x)
|
1262 |
+
|
1263 |
+
y_grad = torch.randn_like(x)
|
1264 |
+
y.backward(gradient=y_grad)
|
1265 |
+
|
1266 |
+
if proportion < 0.2:
|
1267 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1268 |
+
elif proportion > 1.0:
|
1269 |
+
assert not torch.allclose(x.grad, y_grad)
|
1270 |
+
|
1271 |
+
|
1272 |
+
def _test_whiten():
|
1273 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1274 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1275 |
+
x = torch.randn(100, 128)
|
1276 |
+
direction = torch.randn(128)
|
1277 |
+
coeffs = torch.randn(100, 1)
|
1278 |
+
x += proportion * direction * coeffs
|
1279 |
+
|
1280 |
+
x.requires_grad = True
|
1281 |
+
|
1282 |
+
num_channels = 128
|
1283 |
+
m = Whiten(
|
1284 |
+
1, 5.0, prob=1.0, grad_scale=0.1  # num_groups=1, whitening_limit=5.0
|
1285 |
+
)
|
1286 |
+
|
1287 |
+
for _ in range(4):
|
1288 |
+
y = m(x)
|
1289 |
+
|
1290 |
+
y_grad = torch.randn_like(x)
|
1291 |
+
y.backward(gradient=y_grad)
|
1292 |
+
|
1293 |
+
if proportion < 0.2:
|
1294 |
+
assert torch.allclose(x.grad, y_grad)
|
1295 |
+
elif proportion > 1.0:
|
1296 |
+
assert not torch.allclose(x.grad, y_grad)
|
1297 |
+
|
1298 |
+
|
1299 |
+
def _test_activation_balancer_sign():
|
1300 |
+
probs = torch.arange(0, 1, 0.01)
|
1301 |
+
N = 1000
|
1302 |
+
x = 1.0 * (
|
1303 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1304 |
+
)
|
1305 |
+
x = x.detach()
|
1306 |
+
x.requires_grad = True
|
1307 |
+
m = ActivationBalancer(
|
1308 |
+
probs.numel(),
|
1309 |
+
channel_dim=0,
|
1310 |
+
min_positive=0.05,
|
1311 |
+
max_positive=0.95,
|
1312 |
+
max_factor=0.2,
|
1313 |
+
min_abs=0.0,
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1317 |
+
|
1318 |
+
y = m(x)
|
1319 |
+
y.backward(gradient=y_grad)
|
1320 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1321 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1322 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1323 |
+
|
1324 |
+
|
1325 |
+
def _test_activation_balancer_magnitude():
|
1326 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1327 |
+
N = 1000
|
1328 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1329 |
+
-1
|
1330 |
+
)
|
1331 |
+
x = x.detach()
|
1332 |
+
x.requires_grad = True
|
1333 |
+
m = ActivationBalancer(
|
1334 |
+
magnitudes.numel(),
|
1335 |
+
channel_dim=0,
|
1336 |
+
min_positive=0.0,
|
1337 |
+
max_positive=1.0,
|
1338 |
+
max_factor=0.2,
|
1339 |
+
min_abs=0.2,
|
1340 |
+
max_abs=0.8,
|
1341 |
+
min_prob=1.0,
|
1342 |
+
)
|
1343 |
+
|
1344 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1345 |
+
|
1346 |
+
y = m(x)
|
1347 |
+
y.backward(gradient=y_grad)
|
1348 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1349 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1350 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1351 |
+
|
1352 |
+
|
1353 |
+
def _test_basic_norm():
|
1354 |
+
num_channels = 128
|
1355 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1356 |
+
|
1357 |
+
x = torch.randn(500, num_channels)
|
1358 |
+
|
1359 |
+
y = m(x)
|
1360 |
+
|
1361 |
+
assert y.shape == x.shape
|
1362 |
+
x_rms = (x ** 2).mean().sqrt()
|
1363 |
+
y_rms = (y ** 2).mean().sqrt()
|
1364 |
+
print("x rms = ", x_rms)
|
1365 |
+
print("y rms = ", y_rms)
|
1366 |
+
assert y_rms < x_rms
|
1367 |
+
assert y_rms > 0.5 * x_rms
|
1368 |
+
|
1369 |
+
|
1370 |
+
def _test_double_swish_deriv():
|
1371 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1372 |
+
x.requires_grad = True
|
1373 |
+
m = DoubleSwish()
|
1374 |
+
|
1375 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1376 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1377 |
+
|
1378 |
+
# for self-test.
|
1379 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1380 |
+
x.requires_grad = True
|
1381 |
+
y = m(x)
|
1382 |
+
|
1383 |
+
|
1384 |
+
def _test_softmax():
|
1385 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1386 |
+
b = a.clone()
|
1387 |
+
a.requires_grad = True
|
1388 |
+
b.requires_grad = True
|
1389 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1390 |
+
print("a grad = ", a.grad)
|
1391 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1392 |
+
print("b grad = ", b.grad)
|
1393 |
+
assert torch.allclose(a.grad, b.grad)
|
1394 |
+
|
1395 |
+
|
1396 |
+
if __name__ == "__main__":
|
1397 |
+
logging.getLogger().setLevel(logging.INFO)
|
1398 |
+
torch.set_num_threads(1)
|
1399 |
+
torch.set_num_interop_threads(1)
|
1400 |
+
_test_softmax()
|
1401 |
+
_test_whiten()
|
1402 |
+
_test_max_eig()
|
1403 |
+
_test_activation_balancer_sign()
|
1404 |
+
_test_activation_balancer_magnitude()
|
1405 |
+
_test_basic_norm()
|
1406 |
+
_test_double_swish_deriv()
|
models/modules/transformer.py
ADDED
@@ -0,0 +1,1089 @@
1 |
+
import copy
import logging
|
2 |
+
import numbers
|
3 |
+
from functools import partial
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from .activation import MultiheadAttention
|
11 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
12 |
+
from .scaling import BasicNorm as _BasicNorm
|
13 |
+
|
14 |
+
_shape_t = Union[int, List[int], torch.Size]
|
15 |
+
|
16 |
+
|
17 |
+
class LayerNorm(nn.Module):
|
18 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
19 |
+
normalized_shape: Tuple[int, ...]
|
20 |
+
eps: float
|
21 |
+
elementwise_affine: bool
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
normalized_shape: _shape_t,
|
26 |
+
eps: float = 1e-5,
|
27 |
+
elementwise_affine: bool = True,
|
28 |
+
device=None,
|
29 |
+
dtype=None,
|
30 |
+
) -> None:
|
31 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
32 |
+
super(LayerNorm, self).__init__()
|
33 |
+
if isinstance(normalized_shape, numbers.Integral):
|
34 |
+
# mypy error: incompatible types in assignment
|
35 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
36 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
37 |
+
self.eps = eps
|
38 |
+
self.elementwise_affine = elementwise_affine
|
39 |
+
if self.elementwise_affine:
|
40 |
+
self.weight = nn.Parameter(
|
41 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
42 |
+
)
|
43 |
+
self.bias = nn.Parameter(
|
44 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
self.register_parameter("weight", None)
|
48 |
+
self.register_parameter("bias", None)
|
49 |
+
|
50 |
+
self.reset_parameters()
|
51 |
+
|
52 |
+
def reset_parameters(self) -> None:
|
53 |
+
if self.elementwise_affine:
|
54 |
+
nn.init.ones_(self.weight)
|
55 |
+
nn.init.zeros_(self.bias)
|
56 |
+
|
57 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
58 |
+
if isinstance(input, tuple):
|
59 |
+
input, embedding = input
|
60 |
+
return (
|
61 |
+
F.layer_norm(
|
62 |
+
input,
|
63 |
+
self.normalized_shape,
|
64 |
+
self.weight,
|
65 |
+
self.bias,
|
66 |
+
self.eps,
|
67 |
+
),
|
68 |
+
embedding,
|
69 |
+
)
|
70 |
+
|
71 |
+
assert embedding is None
|
72 |
+
return F.layer_norm(
|
73 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
74 |
+
)
|
75 |
+
|
76 |
+
def extra_repr(self) -> str:
|
77 |
+
return (
|
78 |
+
"{normalized_shape}, eps={eps}, "
|
79 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class AdaptiveLayerNorm(nn.Module):
|
84 |
+
r"""Adaptive Layer Normalization"""
|
85 |
+
|
86 |
+
def __init__(self, d_model, norm) -> None:
|
87 |
+
super(AdaptiveLayerNorm, self).__init__()
|
88 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
89 |
+
self.norm = norm
|
90 |
+
self.d_model = d_model
|
91 |
+
self.eps = self.norm.eps
|
92 |
+
|
93 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
94 |
+
if isinstance(input, tuple):
|
95 |
+
input, embedding = input
|
96 |
+
weight, bias = torch.split(
|
97 |
+
self.project_layer(embedding),
|
98 |
+
split_size_or_sections=self.d_model,
|
99 |
+
dim=-1,
|
100 |
+
)
|
101 |
+
return (weight * self.norm(input) + bias, embedding)
|
102 |
+
|
103 |
+
weight, bias = torch.split(
|
104 |
+
self.project_layer(embedding),
|
105 |
+
split_size_or_sections=self.d_model,
|
106 |
+
dim=-1,
|
107 |
+
)
|
108 |
+
return weight * self.norm(input) + bias
|
109 |
+
|
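A minimal illustration of AdaptiveLayerNorm (editor-added; shapes are assumptions): the conditioning embedding is projected into a per-channel weight and bias that modulate the wrapped norm's output.

# editor's example
d_model = 16
ada = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
x = torch.randn(2, 10, d_model)   # (batch, seq, d_model)
emb = torch.randn(2, 1, d_model)  # conditioning embedding, broadcast over seq
out = ada(x, emb)                 # weight * norm(x) + bias, both derived from emb
print(out.shape)                  # torch.Size([2, 10, 16])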
110 |
+
|
111 |
+
class BasicNorm(_BasicNorm):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
d_model: int,
|
115 |
+
eps: float = 1e-5,
|
116 |
+
device=None,
|
117 |
+
dtype=None,
|
118 |
+
):
|
119 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
120 |
+
|
121 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
122 |
+
if isinstance(input, tuple):
|
123 |
+
input, embedding = input
|
124 |
+
return (
|
125 |
+
super(BasicNorm, self).forward(input),
|
126 |
+
embedding,
|
127 |
+
)
|
128 |
+
|
129 |
+
assert embedding is None
|
130 |
+
return super(BasicNorm, self).forward(input)
|
131 |
+
|
132 |
+
|
133 |
+
class BalancedBasicNorm(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
d_model: int,
|
137 |
+
eps: float = 1e-5,
|
138 |
+
device=None,
|
139 |
+
dtype=None,
|
140 |
+
):
|
141 |
+
super(BalancedBasicNorm, self).__init__()
|
142 |
+
self.balancer = ActivationBalancer(
|
143 |
+
d_model,
|
144 |
+
channel_dim=-1,
|
145 |
+
min_positive=0.45,
|
146 |
+
max_positive=0.55,
|
147 |
+
max_abs=6.0,
|
148 |
+
)
|
149 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
150 |
+
|
151 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
152 |
+
if isinstance(input, tuple):
|
153 |
+
input, embedding = input
|
154 |
+
return self.norm((self.balancer(input), embedding))
|
155 |
+
|
156 |
+
assert embedding is None
|
157 |
+
return self.norm(self.balancer(input))
|
158 |
+
|
159 |
+
|
160 |
+
class IdentityNorm(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
d_model: int,
|
164 |
+
eps: float = 1e-5,
|
165 |
+
device=None,
|
166 |
+
dtype=None,
|
167 |
+
) -> None:
|
168 |
+
super(IdentityNorm, self).__init__()
|
169 |
+
|
170 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
171 |
+
if isinstance(input, tuple):
|
172 |
+
return input
|
173 |
+
|
174 |
+
assert embedding is None
|
175 |
+
return input
|
176 |
+
|
177 |
+
|
178 |
+
class TransformerEncoderLayer(nn.Module):
|
179 |
+
__constants__ = ["batch_first", "norm_first"]
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
d_model: int,
|
184 |
+
nhead: int,
|
185 |
+
dim_feedforward: int = 2048,
|
186 |
+
dropout: float = 0.1,
|
187 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
188 |
+
batch_first: bool = False,
|
189 |
+
norm_first: bool = False,
|
190 |
+
device=None,
|
191 |
+
dtype=None,
|
192 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
193 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
194 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
195 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
196 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
197 |
+
layer_norm_eps: float = 1e-5,
|
198 |
+
adaptive_layer_norm=False,
|
199 |
+
) -> None:
|
200 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
201 |
+
super(TransformerEncoderLayer, self).__init__()
|
202 |
+
self.self_attn = MultiheadAttention(
|
203 |
+
d_model,
|
204 |
+
nhead,
|
205 |
+
dropout=dropout,
|
206 |
+
batch_first=batch_first,
|
207 |
+
linear1_cls=linear1_self_attention_cls,
|
208 |
+
linear2_cls=linear2_self_attention_cls,
|
209 |
+
**factory_kwargs,
|
210 |
+
)
|
211 |
+
|
212 |
+
# Implementation of Feedforward model
|
213 |
+
self.linear1 = linear1_feedforward_cls(
|
214 |
+
d_model, dim_feedforward, **factory_kwargs
|
215 |
+
)
|
216 |
+
self.dropout = nn.Dropout(dropout)
|
217 |
+
self.linear2 = linear2_feedforward_cls(
|
218 |
+
dim_feedforward, d_model, **factory_kwargs
|
219 |
+
)
|
220 |
+
|
221 |
+
self.norm_first = norm_first
|
222 |
+
self.dropout1 = nn.Dropout(dropout)
|
223 |
+
self.dropout2 = nn.Dropout(dropout)
|
224 |
+
|
225 |
+
# Legacy string support for activation function.
|
226 |
+
if isinstance(activation, str):
|
227 |
+
activation = _get_activation_fn(activation)
|
228 |
+
elif isinstance(activation, partial):
|
229 |
+
activation = activation(d_model)
|
230 |
+
elif activation == BalancedDoubleSwish:
|
231 |
+
activation = BalancedDoubleSwish(d_model)
|
232 |
+
|
233 |
+
# # We can't test self.activation in forward() in TorchScript,
|
234 |
+
# # so stash some information about it instead.
|
235 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
236 |
+
# self.activation_relu_or_gelu = 1
|
237 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
238 |
+
# self.activation_relu_or_gelu = 2
|
239 |
+
# else:
|
240 |
+
# self.activation_relu_or_gelu = 0
|
241 |
+
self.activation = activation
|
242 |
+
|
243 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
244 |
+
if layer_norm_cls == IdentityNorm:
|
245 |
+
norm2 = BalancedBasicNorm(
|
246 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
norm2 = layer_norm_cls(
|
250 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
251 |
+
)
|
252 |
+
|
253 |
+
if adaptive_layer_norm:
|
254 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
255 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
256 |
+
else:
|
257 |
+
self.norm1 = norm1
|
258 |
+
self.norm2 = norm2
|
259 |
+
|
260 |
+
def __setstate__(self, state):
|
261 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
262 |
+
if not hasattr(self, "activation"):
|
263 |
+
self.activation = F.relu
|
264 |
+
|
265 |
+
def forward(
|
266 |
+
self,
|
267 |
+
src,
|
268 |
+
src_mask: Optional[Tensor] = None,
|
269 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
270 |
+
need_weights: Optional[bool] = False,
|
271 |
+
past: Optional[Tensor] = None,
|
272 |
+
) -> Tensor:
|
273 |
+
r"""Pass the input through the encoder layer.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
src: the sequence to the encoder layer (required).
|
277 |
+
src_mask: the mask for the src sequence (optional).
|
278 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
279 |
+
|
280 |
+
Shape:
|
281 |
+
see the docs in Transformer class.
|
282 |
+
"""
|
283 |
+
if isinstance(src, dict):
|
284 |
+
sinu = src["sinu"]
|
285 |
+
pm_sinu = src["pm_sinu"]
|
286 |
+
src = src["input"]
|
287 |
+
else:
|
288 |
+
sinu = None
|
289 |
+
pm_sinu = None
|
290 |
+
x, stage_embedding = src, None
|
291 |
+
is_src_tuple = False
|
292 |
+
if isinstance(src, tuple):
|
293 |
+
x, stage_embedding = src
|
294 |
+
is_src_tuple = True
|
295 |
+
|
296 |
+
if src_key_padding_mask is not None:
|
297 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
298 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
299 |
+
src_key_padding_mask
|
300 |
+
):
|
301 |
+
raise AssertionError(
|
302 |
+
"only bool and floating types of key_padding_mask are supported"
|
303 |
+
)
|
304 |
+
if need_weights:
|
305 |
+
raise NotImplementedError
|
306 |
+
if self.norm_first:
|
307 |
+
out, attn = self._sa_block_attn(
|
308 |
+
self.norm1(x, stage_embedding),
|
309 |
+
src_mask,
|
310 |
+
src_key_padding_mask,
|
311 |
+
past, sinu=sinu
|
312 |
+
)
|
313 |
+
out, present = out # present is the kvcache of the present timestep
|
314 |
+
x = x + out
|
315 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
316 |
+
else:
|
317 |
+
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past, sinu=sinu)
|
318 |
+
out, present = out # present is the kvcache of the present timestep
|
319 |
+
x = self.norm1(
|
320 |
+
x + out,
|
321 |
+
stage_embedding,
|
322 |
+
)
|
323 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
324 |
+
assert not is_src_tuple
|
325 |
+
# return (x, stage_embedding)
|
326 |
+
return (x, attn)
|
327 |
+
else:
|
328 |
+
if self.norm_first:
|
329 |
+
out = self._sa_block(
|
330 |
+
self.norm1(x, stage_embedding),
|
331 |
+
src_mask,
|
332 |
+
src_key_padding_mask, past, sinu=sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q']
|
333 |
+
)
|
334 |
+
out, present = out # present is the kvcache of the present timestep
|
335 |
+
x = x + out
|
336 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
337 |
+
else:
|
338 |
+
out = self._sa_block(x, src_mask, src_key_padding_mask, past, sinu=sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'])
|
339 |
+
out, present = out # present is the kvcache of the present timestep
|
340 |
+
x = self.norm1(
|
341 |
+
x + out,
|
342 |
+
stage_embedding,
|
343 |
+
)
|
344 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
345 |
+
|
346 |
+
if is_src_tuple:
|
347 |
+
x = (x, stage_embedding)
|
348 |
+
if present is not None:
|
349 |
+
x = [x, present]
|
350 |
+
return x
|
351 |
+
|
352 |
+
# self-attention block
|
353 |
+
def _sa_block(
|
354 |
+
self,
|
355 |
+
x: Tensor,
|
356 |
+
attn_mask: Optional[Tensor],
|
357 |
+
key_padding_mask: Optional[Tensor],
|
358 |
+
past: Optional[Tensor] = None,
|
359 |
+
sinu = None,
|
360 |
+
q_sinu = None,
|
361 |
+
k_sinu = None
|
362 |
+
) -> Tensor:
|
363 |
+
x = self.self_attn(
|
364 |
+
x,
|
365 |
+
x,
|
366 |
+
x,
|
367 |
+
attn_mask=attn_mask,
|
368 |
+
key_padding_mask=key_padding_mask,
|
369 |
+
need_weights=False,
|
370 |
+
past=past,
|
371 |
+
sinu = sinu,
|
372 |
+
q_sinu = q_sinu,
|
373 |
+
k_sinu = k_sinu
|
374 |
+
)
|
375 |
+
x, present = x
|
376 |
+
return self.dropout1(x), present
|
377 |
+
|
378 |
+
# self-attention block, also return attention weights
|
379 |
+
def _sa_block_attn(
|
380 |
+
self,
|
381 |
+
x: Tensor,
|
382 |
+
attn_mask: Optional[Tensor],
|
383 |
+
key_padding_mask: Optional[Tensor],
|
384 |
+
past: Optional[Tensor] = None,
sinu=None,
|
385 |
+
) -> Tensor:
|
386 |
+
x, attn = self.self_attn(
|
387 |
+
x,
|
388 |
+
x,
|
389 |
+
x,
|
390 |
+
attn_mask=attn_mask,
|
391 |
+
key_padding_mask=key_padding_mask,
|
392 |
+
need_weights=True,
|
393 |
+
past=past,
sinu=sinu,
|
394 |
+
)
|
395 |
+
x, present = x
|
396 |
+
return (self.dropout1(x), present), attn
|
397 |
+
|
398 |
+
# feed forward block
|
399 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
400 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
401 |
+
return self.dropout2(x)
|
402 |
+
|
403 |
+
def pre_compute_sinusoidal(dim, base, max_len=10000):  # for reference, 4000 Mimi codes cover 320 s, since Mimi runs at 12.5 Hz
|
404 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
|
405 |
+
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
|
406 |
+
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
|
407 |
+
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
|
408 |
+
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
|
409 |
+
return {"sin": freqs.sin(), "cos": freqs.cos()}
|
410 |
+
|
411 |
+
def pre_compute_freqs(dim, base, max_len=10000):  # for reference, 4000 Mimi codes cover 320 s, since Mimi runs at 12.5 Hz
|
412 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
|
413 |
+
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
|
414 |
+
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
|
415 |
+
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
|
416 |
+
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
|
417 |
+
return freqs
|
418 |
+
|
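For reference, tables like these are usually consumed with the rotate-half form of rotary embeddings; a sketch under that assumption (apply_rotary is an editor-added helper, not defined in this file):

# editor's example
def apply_rotary(x, sin, cos):
    # x: [B, nhead, T, head_dim]; sin/cos: [1, max_len, head_dim]
    sin = sin[:, : x.shape[-2]].unsqueeze(1)
    cos = cos[:, : x.shape[-2]].unsqueeze(1)
    half = x.shape[-1] // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    return x * cos + rotated * sin

tables = pre_compute_sinusoidal(64, 10000)
q = torch.randn(2, 8, 100, 64)
q_rot = apply_rotary(q, tables["sin"], tables["cos"])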
419 |
+
class TransformerEncoder(nn.Module):
|
420 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
421 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
425 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
426 |
+
norm: the layer normalization component (optional).
|
427 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
428 |
+
(and convert back on output). This will improve the overall performance of
|
429 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
430 |
+
|
431 |
+
Examples::
|
432 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
433 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
434 |
+
>>> src = torch.rand(10, 32, 512)
|
435 |
+
>>> out = transformer_encoder(src)
|
436 |
+
"""
|
437 |
+
__constants__ = ["norm"]
|
438 |
+
|
439 |
+
def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model=None, nhead=None, args=None):
|
440 |
+
super(TransformerEncoder, self).__init__()
|
441 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
442 |
+
self.num_layers = num_layers
|
443 |
+
self.norm = norm
|
444 |
+
if args is not None:
|
445 |
+
self.progress_no_multiple = args.progress_no_multiple
|
446 |
+
self.progress_scale = args.progress_scale
|
447 |
+
else:
|
448 |
+
self.progress_no_multiple = False
|
449 |
+
self.progress_scale = 1
|
450 |
+
|
451 |
+
if rope_base is not None:
|
452 |
+
if self.progress_no_multiple:
|
453 |
+
self.pm_freqs = pre_compute_freqs(d_model//nhead, rope_base)
|
454 |
+
self.sinu = None
|
455 |
+
else:
|
456 |
+
self.sinu = pre_compute_sinusoidal(d_model // nhead, rope_base)
|
457 |
+
self.pm_freqs = None
|
458 |
+
# logging.info(f"get precomputed sinusoidal for {rope_base=}: {self.sinu=}")
|
459 |
+
else:
|
460 |
+
self.sinu = None
|
461 |
+
self.pm_freqs = None
|
462 |
+
|
463 |
+
def forward(
|
464 |
+
self,
|
465 |
+
src: Tensor,
|
466 |
+
mask: Optional[Tensor] = None,
|
467 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
468 |
+
return_layer_states: bool = False,
|
469 |
+
need_weights:Optional[bool] = False,
|
470 |
+
past: Optional[Tensor] = None,
|
471 |
+
) -> Tensor:
|
472 |
+
r"""Pass the input through the encoder layers in turn.
|
473 |
+
|
474 |
+
Args:
|
475 |
+
src: the sequence to the encoder (required).
|
476 |
+
mask: the mask for the src sequence (optional).
|
477 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
478 |
+
return_layer_states: return layers' state (optional).
|
479 |
+
|
480 |
+
Shape:
|
481 |
+
see the docs in Transformer class.
|
482 |
+
"""
|
483 |
+
if return_layer_states:
|
484 |
+
raise NotImplementedError
|
485 |
+
assert not need_weights
|
486 |
+
layer_states = [] # layers' output
|
487 |
+
output = src
|
488 |
+
for mod in self.layers:
|
489 |
+
output = mod(
|
490 |
+
output,
|
491 |
+
src_mask=mask,
|
492 |
+
src_key_padding_mask=src_key_padding_mask,
|
493 |
+
past=past
|
494 |
+
)
|
495 |
+
layer_states.append(output[0])
|
496 |
+
|
497 |
+
if self.norm is not None:
|
498 |
+
output = self.norm(output)
|
499 |
+
|
500 |
+
return layer_states, output
|
501 |
+
if need_weights:
|
502 |
+
raise NotImplementedError
|
503 |
+
assert not return_layer_states
|
504 |
+
layer_attn = [] # layers' output
|
505 |
+
output = src
|
506 |
+
for mod in self.layers:
|
507 |
+
output = mod(
|
508 |
+
output,
|
509 |
+
src_mask=mask,
|
510 |
+
src_key_padding_mask=src_key_padding_mask,
|
511 |
+
need_weights=True,
|
512 |
+
past=past
|
513 |
+
)
|
514 |
+
layer_attn.append(output[1])
|
515 |
+
|
516 |
+
if self.norm is not None:
|
517 |
+
output = self.norm(output)
|
518 |
+
|
519 |
+
return layer_attn, output
|
520 |
+
|
521 |
+
output = src
|
522 |
+
all_present = []
|
523 |
+
if self.sinu is not None:
|
524 |
+
# use rope
|
525 |
+
assert self.pm_freqs is None
|
526 |
+
for k, v in self.sinu.items():
|
527 |
+
self.sinu[k] = v.to(output.device)
|
528 |
+
if self.pm_freqs is not None:
|
529 |
+
assert self.sinu is None
|
530 |
+
self.pm_freqs = self.pm_freqs.to(output.device)
|
531 |
+
if src_key_padding_mask is not None:
|
532 |
+
query_lens = (~src_key_padding_mask).int().sum(-1).to(output.device)
|
533 |
+
else:
|
534 |
+
query_lens = torch.tensor([output.shape[1]]*output.shape[0]).to(output.device)
|
535 |
+
assert query_lens.ndim==1, query_lens
|
536 |
+
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
|
537 |
+
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
|
538 |
+
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
|
539 |
+
q_emb = q_emb / q_lens_expanded * self.progress_scale
|
540 |
+
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
|
541 |
+
q_sin = q_emb.sin().unsqueeze(1)
|
542 |
+
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}}
|
543 |
+
else:
|
544 |
+
self.pm_sinu = {"q": None}
|
545 |
+
|
546 |
+
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
|
547 |
+
for n_layer, mod in enumerate(self.layers):
|
548 |
+
output = mod(
|
549 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
|
550 |
+
)
|
551 |
+
if isinstance(output, list):
|
552 |
+
output, present = output
|
553 |
+
all_present.append(present)
|
554 |
+
if self.sinu is not None or self.pm_sinu is not None:
|
555 |
+
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
|
556 |
+
if self.sinu is not None or self.pm_sinu is not None:
|
557 |
+
output = output["input"]
|
558 |
+
if self.norm is not None:
|
559 |
+
output = self.norm(output)
|
560 |
+
if all_present != []:
|
561 |
+
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
562 |
+
output = [output, all_present]
|
563 |
+
return output
|
564 |
+
|
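A worked example of the progress-scaled positions built in forward() above (my reading of the code, so treat it as a sketch): with sequence length L, raw position i becomes i * (L / (L - 1)) / L = i / (L - 1), so every sequence spans [0, progress_scale] regardless of its length.

# editor's example, L = 5
L = 5.0
ids = torch.arange(L)               # 0, 1, 2, 3, 4
progress = ids * (L / (L - 1)) / L  # 0.00, 0.25, 0.50, 0.75, 1.00
print(progress * 1.0)               # then multiplied by progress_scale before sin/cos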
565 |
+
|
566 |
+
class TransformerDecoderLayer(nn.Module):
|
567 |
+
__constants__ = ["batch_first", "norm_first"]
|
568 |
+
|
569 |
+
def __init__(
|
570 |
+
self,
|
571 |
+
d_model: int,
|
572 |
+
nhead: int,
|
573 |
+
dim_feedforward: int = 2048,
|
574 |
+
dropout: float = 0.1,
|
575 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
576 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
577 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
578 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
579 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
580 |
+
batch_first: bool = False,
|
581 |
+
norm_first: bool = False,
|
582 |
+
device=None,
|
583 |
+
dtype=None,
|
584 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
585 |
+
layer_norm_eps: float = 1e-5,
|
586 |
+
adaptive_layer_norm=False,
|
587 |
+
) -> None:
|
588 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
589 |
+
super(TransformerDecoderLayer, self).__init__()
|
590 |
+
self.self_attn = MultiheadAttention(
|
591 |
+
d_model,
|
592 |
+
nhead,
|
593 |
+
dropout=dropout,
|
594 |
+
batch_first=batch_first,
|
595 |
+
linear1_cls=linear1_self_attention_cls,
|
596 |
+
linear2_cls=linear2_self_attention_cls,
|
597 |
+
**factory_kwargs,
|
598 |
+
)
|
599 |
+
self.multihead_attn = MultiheadAttention(
|
600 |
+
d_model,
|
601 |
+
nhead,
|
602 |
+
dropout=dropout,
|
603 |
+
batch_first=batch_first,
|
604 |
+
linear1_cls=linear1_self_attention_cls,
|
605 |
+
linear2_cls=linear2_self_attention_cls,
|
606 |
+
**factory_kwargs,
|
607 |
+
)
|
608 |
+
# Implementation of Feedforward model
|
609 |
+
self.linear1 = linear1_feedforward_cls(
|
610 |
+
d_model, dim_feedforward, **factory_kwargs
|
611 |
+
)
|
612 |
+
self.dropout = nn.Dropout(dropout)
|
613 |
+
self.linear2 = linear2_feedforward_cls(
|
614 |
+
dim_feedforward, d_model, **factory_kwargs
|
615 |
+
)
|
616 |
+
|
617 |
+
self.norm_first = norm_first
|
618 |
+
self.dropout1 = nn.Dropout(dropout)
|
619 |
+
self.dropout2 = nn.Dropout(dropout)
|
620 |
+
self.dropout3 = nn.Dropout(dropout)
|
621 |
+
|
622 |
+
# Legacy string support for activation function.
|
623 |
+
if isinstance(activation, str):
|
624 |
+
self.activation = _get_activation_fn(activation)
|
625 |
+
elif isinstance(activation, partial):
|
626 |
+
self.activation = activation(d_model)
|
627 |
+
elif activation == BalancedDoubleSwish:
|
628 |
+
self.activation = BalancedDoubleSwish(d_model)
|
629 |
+
else:
|
630 |
+
self.activation = activation
|
631 |
+
|
632 |
+
if adaptive_layer_norm:
|
633 |
+
norm1 = layer_norm_cls(
|
634 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
635 |
+
)
|
636 |
+
norm2 = layer_norm_cls(
|
637 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
638 |
+
)
|
639 |
+
norm3 = layer_norm_cls(
|
640 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
641 |
+
)
|
642 |
+
|
643 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
644 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
645 |
+
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
|
646 |
+
else:
|
647 |
+
self.norm1 = layer_norm_cls(
|
648 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
649 |
+
)
|
650 |
+
self.norm2 = layer_norm_cls(
|
651 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
652 |
+
)
|
653 |
+
if layer_norm_cls == IdentityNorm:
|
654 |
+
self.norm3 = BalancedBasicNorm(
|
655 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
656 |
+
)
|
657 |
+
else:
|
658 |
+
self.norm3 = layer_norm_cls(
|
659 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
660 |
+
)
|
661 |
+
|
662 |
+
def forward(
|
663 |
+
self,
|
664 |
+
tgt: Tensor,
|
665 |
+
memory: Tensor,
|
666 |
+
tgt_mask: Optional[Tensor] = None,
|
667 |
+
memory_mask: Optional[Tensor] = None,
|
668 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
669 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
670 |
+
tgt_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
|
671 |
+
memory_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
|
672 |
+
past: Optional[Tensor] = None,
|
673 |
+
) -> Tensor:
|
674 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
675 |
+
|
676 |
+
Args:
|
677 |
+
tgt: the sequence to the decoder layer (required).
|
678 |
+
memory: the sequence from the last layer of the encoder (required).
|
679 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
680 |
+
memory_mask: the mask for the memory sequence (optional).
|
681 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
682 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
683 |
+
past: the previous kvcache of the decoder (optional). shape: (2, batch_size, num_heads, seq_len, head_dim)
|
684 |
+
|
685 |
+
Shape:
|
686 |
+
see the docs in Transformer class.
|
687 |
+
"""
|
688 |
+
if isinstance(tgt, dict):
|
689 |
+
pm_sinu = tgt["pm_sinu"]
|
690 |
+
sinu = tgt["sinu"]
|
691 |
+
args = tgt["args"]
|
692 |
+
tgt = tgt["input"]
|
693 |
+
else:
|
694 |
+
pm_sinu = None
|
695 |
+
sinu = None
|
696 |
+
args = None
|
697 |
+
tgt_is_tuple = False
|
698 |
+
if isinstance(tgt, tuple):
|
699 |
+
x, stage_embedding = tgt
|
700 |
+
tgt_is_tuple = True
|
701 |
+
else:
|
702 |
+
x, stage_embedding = tgt, None
|
703 |
+
# logging.info(f"{tgt_key_padding_mask=}, {memory_key_padding_mask=}")
|
704 |
+
# logging.info(f"{tgt_key_padding_mask.shape=}, {memory_key_padding_mask.shape=}")
|
705 |
+
# logging.info(f"{query_lens=}, {key_lens=}")
|
706 |
+
|
707 |
+
# past stores the kvcache for self-attention, and it can also be used to infer q_offset
|
708 |
+
if past is not None and past.ndim > 2:
|
709 |
+
q_offset = past[0].shape[-2] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q
|
710 |
+
else:
|
711 |
+
q_offset = 0
|
712 |
+
|
713 |
+
|
714 |
+
if self.norm_first:
|
715 |
+
temp = self._sa_block(
|
716 |
+
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args=args, past=past, q_offset=q_offset
|
717 |
+
)
|
718 |
+
present = temp[1]
|
719 |
+
x = x + temp[0]
|
720 |
+
cross_out = self._mha_block(
|
721 |
+
self.norm2(x, stage_embedding),
|
722 |
+
memory,
|
723 |
+
memory_mask,
|
724 |
+
memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args, q_offset=q_offset
|
725 |
+
)
|
726 |
+
if isinstance(cross_out, dict):
|
727 |
+
attention_weights = cross_out["attention_weights"]
|
728 |
+
cross_out = cross_out["x"]
|
729 |
+
else:
|
730 |
+
attention_weights = None
|
731 |
+
x = x + cross_out
|
732 |
+
x = x + self._ff_block(self.norm3(x, stage_embedding))
|
733 |
+
else:
|
734 |
+
temp = self._sa_block(x, tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args=args, past=past, q_offset=q_offset)
|
735 |
+
present = temp[1]
|
736 |
+
x = self.norm1(
|
737 |
+
x + temp[0],
|
738 |
+
stage_embedding,
|
739 |
+
)
|
740 |
+
cross_out = self._mha_block(
|
741 |
+
x, memory, memory_mask, memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args, q_offset=q_offset
|
742 |
+
)
|
743 |
+
if isinstance(cross_out, dict):
|
744 |
+
attention_weights = cross_out["attention_weights"]
|
745 |
+
cross_out = cross_out["x"]
|
746 |
+
else:
|
747 |
+
attention_weights = None
|
748 |
+
x = self.norm2(
|
749 |
+
x
|
750 |
+
+ cross_out,
|
751 |
+
stage_embedding,
|
752 |
+
)
|
753 |
+
x = self.norm3(x + self._ff_block(x), stage_embedding)
|
754 |
+
|
755 |
+
if attention_weights is not None:
|
756 |
+
x = {"x": x, "attention_weights": attention_weights}
|
757 |
+
if tgt_is_tuple:
|
758 |
+
x = (x, stage_embedding)
|
759 |
+
if present is not None:
|
760 |
+
x = [x, present]
|
761 |
+
return x
|
762 |
+
|
763 |
+
# self-attention block
|
764 |
+
def _sa_block(
|
765 |
+
self,
|
766 |
+
x: Tensor,
|
767 |
+
attn_mask: Optional[Tensor],
|
768 |
+
key_padding_mask: Optional[Tensor],
|
769 |
+
q_sinu=None,
|
770 |
+
k_sinu=None,
|
771 |
+
sinu = None,
|
772 |
+
args = None,
|
773 |
+
past = None,
|
774 |
+
q_offset = 0
|
775 |
+
) -> Tensor:
|
776 |
+
# if past is not None and past.ndim > 2:
|
777 |
+
# print(f"self-attn, k len: {past[0].shape[-2] + x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
|
778 |
+
# else:
|
779 |
+
# print(f"self-attn, k len: {x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
|
780 |
+
x = self.self_attn(
|
781 |
+
x,
|
782 |
+
x,
|
783 |
+
x,
|
784 |
+
attn_mask=attn_mask,
|
785 |
+
key_padding_mask=key_padding_mask,
|
786 |
+
need_weights=False,
|
787 |
+
q_sinu = q_sinu,
|
788 |
+
k_sinu = k_sinu,
|
789 |
+
sinu = sinu,
|
790 |
+
past = past,
|
791 |
+
q_offset = q_offset
|
792 |
+
)
|
793 |
+
x, present = x
|
794 |
+
return self.dropout1(x), present
|
795 |
+
|
796 |
+
# multihead attention block
|
797 |
+
def _mha_block(
|
798 |
+
self,
|
799 |
+
x: Tensor,
|
800 |
+
mem: Tensor,
|
801 |
+
attn_mask: Optional[Tensor],
|
802 |
+
key_padding_mask: Optional[Tensor],
|
803 |
+
q_sinu = None,
|
804 |
+
k_sinu = None,
|
805 |
+
sinu = None,
|
806 |
+
args = None,
|
807 |
+
q_offset = 0
|
808 |
+
) -> Tensor:
|
809 |
+
# print(f"cross-attn, k len: {mem.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
|
810 |
+
x = self.multihead_attn(
|
811 |
+
x,
|
812 |
+
mem,
|
813 |
+
mem,
|
814 |
+
attn_mask=attn_mask,
|
815 |
+
key_padding_mask=key_padding_mask,
|
816 |
+
need_weights=False,
|
817 |
+
q_sinu = q_sinu,
|
818 |
+
k_sinu = k_sinu,
|
819 |
+
sinu = sinu,
|
820 |
+
args = args,
|
821 |
+
q_offset = q_offset
|
822 |
+
)
|
823 |
+
if len(x) == 2 and isinstance(x[0], dict) and "attention_weights" in x[0]:
|
824 |
+
x, present = x
|
825 |
+
attention_weights = x['attention_weights']
|
826 |
+
x = x['attn_output']
|
827 |
+
return {"x": self.dropout2(x), "attention_weights": attention_weights}
|
828 |
+
elif len(x) == 2:
|
829 |
+
x = x[0]
|
830 |
+
return self.dropout2(x)
|
831 |
+
|
832 |
+
# feed forward block
|
833 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
834 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
835 |
+
return self.dropout3(x)
|
836 |
+
|
837 |
+
|
838 |
+
def _get_clones(module, N):
|
839 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
840 |
+
|
841 |
+
|
842 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
843 |
+
if activation == "relu":
|
844 |
+
return F.relu
|
845 |
+
elif activation == "gelu":
|
846 |
+
return F.gelu
|
847 |
+
|
848 |
+
raise RuntimeError(
|
849 |
+
"activation should be relu/gelu, not {}".format(activation)
|
850 |
+
)
|
851 |
+
def _generate_square_subsequent_mask(
|
852 |
+
sz: int,
|
853 |
+
device: Optional[torch.device] = None,
|
854 |
+
dtype: Optional[torch.dtype] = None,
|
855 |
+
) -> Tensor:
|
856 |
+
r"""Generate a square causal mask for the sequence.
|
857 |
+
|
858 |
+
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
859 |
+
"""
|
860 |
+
if device is None:
|
861 |
+
device = torch.device('cpu')
|
862 |
+
if dtype is None:
|
863 |
+
dtype = torch.float32
|
864 |
+
return torch.triu(
|
865 |
+
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
|
866 |
+
diagonal=1,
|
867 |
+
)
|
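For concreteness, the helper above yields, for sz=3 (editor-added check):

# editor's example
print(_generate_square_subsequent_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])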
868 |
+
def _get_seq_len(
|
869 |
+
src: Tensor,
|
870 |
+
batch_first: bool
|
871 |
+
) -> Optional[int]:
|
872 |
+
|
873 |
+
if src.is_nested:
|
874 |
+
return None
|
875 |
+
else:
|
876 |
+
src_size = src.size()
|
877 |
+
if len(src_size) == 2:
|
878 |
+
# unbatched: S, E
|
879 |
+
return src_size[0]
|
880 |
+
else:
|
881 |
+
# batched: B, S, E if batch_first else S, B, E
|
882 |
+
seq_len_pos = 1 if batch_first else 0
|
883 |
+
return src_size[seq_len_pos]
|
884 |
+
|
885 |
+
def _detect_is_causal_mask(
|
886 |
+
mask: Optional[Tensor],
|
887 |
+
is_causal: Optional[bool] = None,
|
888 |
+
size: Optional[int] = None,
|
889 |
+
) -> bool:
|
890 |
+
"""Return whether the given attention mask is causal.
|
891 |
+
|
892 |
+
Warning:
|
893 |
+
If ``is_causal`` is not ``None``, its value will be returned as is. If a
|
894 |
+
user supplies an incorrect ``is_causal`` hint,
|
895 |
+
|
896 |
+
``is_causal=False`` when the mask is in fact a causal attention mask
|
897 |
+
may lead to reduced performance relative to what would be achievable
|
898 |
+
with ``is_causal=True``;
|
899 |
+
``is_causal=True`` when the mask is in fact not a causal attention mask
|
900 |
+
may lead to incorrect and unpredictable execution - in some scenarios,
|
901 |
+
a causal mask may be applied based on the hint, in other execution
|
902 |
+
scenarios the specified mask may be used. The choice may not appear
|
903 |
+
to be deterministic, in that a number of factors like alignment,
|
904 |
+
hardware SKU, etc influence the decision whether to use a mask or
|
905 |
+
rely on the hint.
|
906 |
+
``size`` if not None, check whether the mask is a causal mask of the provided size
|
907 |
+
Otherwise, checks for any causal mask.
|
908 |
+
"""
|
909 |
+
# Prevent type refinement
|
910 |
+
make_causal = (is_causal is True)
|
911 |
+
|
912 |
+
if is_causal is None and mask is not None:
|
913 |
+
sz = size if size is not None else mask.size(-2)
|
914 |
+
causal_comparison = _generate_square_subsequent_mask(
|
915 |
+
sz, device=mask.device, dtype=mask.dtype)
|
916 |
+
|
917 |
+
# Do not use `torch.equal` so we handle batched masks by
|
918 |
+
# broadcasting the comparison.
|
919 |
+
if mask.size() == causal_comparison.size():
|
920 |
+
make_causal = bool((mask == causal_comparison).all())
|
921 |
+
else:
|
922 |
+
make_causal = False
|
923 |
+
|
924 |
+
return make_causal
|
925 |
+
|
926 |
+
class TransformerDecoder(nn.Module):
|
927 |
+
r"""TransformerDecoder is a stack of N decoder layers.
|
928 |
+
|
929 |
+
Args:
|
930 |
+
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
931 |
+
num_layers: the number of sub-decoder-layers in the decoder (required).
|
932 |
+
norm: the layer normalization component (optional).
|
933 |
+
|
934 |
+
Examples::
|
935 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
936 |
+
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
|
937 |
+
>>> memory = torch.rand(10, 32, 512)
|
938 |
+
>>> tgt = torch.rand(20, 32, 512)
|
939 |
+
>>> out = transformer_decoder(tgt, memory)
|
940 |
+
"""
|
941 |
+
|
942 |
+
__constants__ = ['norm']
|
943 |
+
|
944 |
+
def __init__(
|
945 |
+
self,
|
946 |
+
decoder_layer: "TransformerDecoderLayer",
|
947 |
+
num_layers: int,
|
948 |
+
norm: Optional[nn.Module] = None,
|
949 |
+
rope_base=None,
|
950 |
+
d_model=None,
|
951 |
+
nhead=None,
|
952 |
+
args=None
|
953 |
+
) -> None:
|
954 |
+
super().__init__()
|
955 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
956 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
957 |
+
self.num_layers = num_layers
|
958 |
+
self.norm = norm
|
959 |
+
self.args = args
|
960 |
+
if getattr(self.args, 'decoder_regular_rope', False):
|
961 |
+
self.sinu = pre_compute_sinusoidal(d_model // nhead, rope_base)
|
962 |
+
self.pm_freqs = None
|
963 |
+
else:
|
964 |
+
self.sinu = None
|
965 |
+
if rope_base is not None:
|
966 |
+
self.pm_freqs = pre_compute_freqs(d_model // nhead, rope_base)
|
967 |
+
# logging.info(f"get precomputed freqs for {rope_base=}: {self.freqs=}")
|
968 |
+
else:
|
969 |
+
self.pm_freqs = None
|
970 |
+
self.progress_scale = getattr(self.args, "progress_scale", 1.0)
|
971 |
+
|
972 |
+
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
|
973 |
+
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
|
974 |
+
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
|
975 |
+
memory_is_causal: bool = False, query_lens: Optional[Tensor] = None, key_lens: Optional[Tensor] = None, past: Optional[Tensor] = None) -> Tensor:
|
976 |
+
r"""Pass the inputs (and mask) through the decoder layer in turn.
|
977 |
+
|
978 |
+
Args:
|
979 |
+
tgt: the sequence to the decoder (required).
|
980 |
+
memory: the sequence from the last layer of the encoder (required).
|
981 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
982 |
+
memory_mask: the mask for the memory sequence (optional).
|
983 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
984 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
985 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
986 |
+
Default: ``None``; try to detect a causal mask.
|
987 |
+
Warning:
|
988 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
989 |
+
the causal mask. Providing incorrect hints can result in
|
990 |
+
incorrect execution, including forward and backward
|
991 |
+
compatibility.
|
992 |
+
memory_is_causal: If specified, applies a causal mask as
|
993 |
+
``memory mask``.
|
994 |
+
Default: ``False``.
|
995 |
+
Warning:
|
996 |
+
``memory_is_causal`` provides a hint that
|
997 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
998 |
+
hints can result in incorrect execution, including
|
999 |
+
forward and backward compatibility.
|
1000 |
+
|
1001 |
+
Shape:
|
1002 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
1003 |
+
"""
|
1004 |
+
output = tgt
|
1005 |
+
|
1006 |
+
# seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
|
1007 |
+
# tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
|
1008 |
+
if self.sinu is not None:
|
1009 |
+
assert self.pm_freqs is None
|
1010 |
+
for key in self.sinu:
|
1011 |
+
self.sinu[key] = self.sinu[key].to(output.device)
|
1012 |
+
if self.pm_freqs is not None:
|
1013 |
+
assert self.sinu is None
|
1014 |
+
if not self.training and hasattr(self, "pm_sinu") and past is not None and past[0].ndim > 2: # inference mode, will use cached sinu for the same example
|
1015 |
+
assert self.pm_sinu['q'] is not None and self.pm_sinu['k'] is not None
|
1016 |
+
# check batch size, need to modify the batch size if we use multi_trial during inference
|
1017 |
+
if self.pm_sinu['q']['cos'].shape[0] != tgt.shape[0]:
|
1018 |
+
if self.pm_sinu['q']['cos'].shape[0] > tgt.shape[0]:
|
1019 |
+
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'][:tgt.shape[0]]
|
1020 |
+
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'][:tgt.shape[0]]
|
1021 |
+
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'][:tgt.shape[0]]
|
1022 |
+
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'][:tgt.shape[0]]
|
1023 |
+
else:
|
1024 |
+
assert self.pm_sinu['q']['cos'].shape[0] == 1
|
1025 |
+
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'].repeat(tgt.shape[0], 1, 1, 1)
|
1026 |
+
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'].repeat(tgt.shape[0], 1, 1, 1)
|
1027 |
+
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'].repeat(tgt.shape[0], 1, 1, 1)
|
1028 |
+
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'].repeat(tgt.shape[0], 1, 1, 1)
|
1029 |
+
pass
|
1030 |
+
else:
|
1031 |
+
self.pm_freqs = self.pm_freqs.to(output.device)
|
1032 |
+
if query_lens is None:
|
1033 |
+
query_lens = (~tgt_key_padding_mask).int().sum(-1).to(tgt.device)
|
1034 |
+
if key_lens is None:
|
1035 |
+
key_lens = (~memory_key_padding_mask).int().sum(-1).to(tgt.device)
|
1036 |
+
assert key_lens.ndim==1, key_lens
|
1037 |
+
assert query_lens.ndim==1, query_lens
|
1038 |
+
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
|
1039 |
+
k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
|
1040 |
+
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
|
1041 |
+
key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1)
|
1042 |
+
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
|
1043 |
+
k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d]
|
1044 |
+
q_emb = q_emb / q_lens_expanded * self.progress_scale
|
1045 |
+
k_emb = k_emb / k_lens_expanded * self.progress_scale
|
1046 |
+
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
|
1047 |
+
q_sin = q_emb.sin().unsqueeze(1)
|
1048 |
+
k_cos = k_emb.cos().unsqueeze(1)
|
1049 |
+
k_sin = k_emb.sin().unsqueeze(1)
|
1050 |
+
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}, "k": {"cos": k_cos, "sin": k_sin}}
|
1051 |
+
else:
|
1052 |
+
self.pm_sinu = {"q": None, "k": None}
|
1053 |
+
|
1054 |
+
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
|
1055 |
+
if past != None:
|
1056 |
+
all_present = []
|
1057 |
+
if self.training and getattr(self.args, "attention_alignment_loss", 0):
|
1058 |
+
all_attn_weights = []
|
1059 |
+
for i, mod in enumerate(self.layers):
|
1060 |
+
output = mod(output, memory, tgt_mask=tgt_mask,
|
1061 |
+
memory_mask=memory_mask,
|
1062 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
1063 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
1064 |
+
past=past[i] if past != None else None
|
1065 |
+
# tgt_is_causal=tgt_is_causal,
|
1066 |
+
# memory_is_causal=memory_is_causal
|
1067 |
+
)
|
1068 |
+
if past != None:
|
1069 |
+
output, cur_present = output
|
1070 |
+
all_present.append(cur_present)
|
1071 |
+
if isinstance(output, dict):
|
1072 |
+
current_attn_weights = output["attention_weights"]
|
1073 |
+
all_attn_weights.append(current_attn_weights)
|
1074 |
+
output = output["x"]
|
1075 |
+
if self.sinu is not None or self.pm_sinu is not None:
|
1076 |
+
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
|
1077 |
+
if self.pm_sinu is not None or self.sinu is not None:
|
1078 |
+
output = output["input"]
|
1079 |
+
if self.norm is not None:
|
1080 |
+
output = self.norm(output)
|
1081 |
+
if self.training and getattr(self.args, "attention_alignment_loss", 0):
|
1082 |
+
assert len(all_attn_weights) == self.num_layers, f"{len(all_attn_weights)=}, {self.num_layers=}"
|
1083 |
+
output = {"output": output, "attention_weights": all_attn_weights}
|
1084 |
+
if past != None:
|
1085 |
+
all_present = torch.stack(all_present, dim=0)
|
1086 |
+
output = [output, all_present]
|
1087 |
+
else:
|
1088 |
+
output = [output, None]
|
1089 |
+
return output
|
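The pm_freqs branch above is the progress-monitoring variant of RoPE: instead of rotating by the absolute position, each position is mapped to its relative progress through the sequence (roughly t/(len-1), optionally scaled by progress_scale), so the end of a short target and the end of a long target receive the same angle. A minimal standalone sketch of that normalization; dim, base, and the lengths are illustrative values, not the repo's pre_compute_freqs implementation:

import torch

# Illustrative sketch of the progress normalization above (assumed values).
def absolute_angles(max_len, dim=64, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(max_len).float()
    return torch.outer(pos, inv_freq)                  # [max_len, dim/2]

freqs = absolute_angles(max_len=100)
lens = torch.tensor([50.0, 100.0]).view(-1, 1, 1)      # [B, 1, 1], true lengths
multiple = lens / (lens - 1)                           # mirrors query_ids_multiple above
progress = freqs * multiple / lens                     # angle ~ t/(len-1), i.e. progress
# the last real token of both sequences lands on progress 1.0, independent of length:
assert torch.allclose(progress[0, 49], progress[1, 99])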
models/modules/utils.py
ADDED
@@ -0,0 +1,36 @@
import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expanded_lengths >= lengths.unsqueeze(-1)

def generate_partial_autoregressive_mask(sz, start, end):
    mask = torch.zeros(sz, sz).bool()
    mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start, dtype=torch.bool), diagonal=1)
    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask
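generate_partial_autoregressive_mask returns a mask in PyTorch's "True = blocked" convention: the span [start, end) is causal with respect to itself, positions outside the span are blocked from attending into it, and the span can still attend to the outside context. A small illustrative example of its output (sizes chosen arbitrarily):

import torch

mask = generate_partial_autoregressive_mask(sz=5, start=1, end=3)
print(mask.int())
# tensor([[0, 1, 1, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 1, 1, 0, 0]], dtype=torch.int32)
# rows 1-2 (the span) are causal among themselves; every other row has
# columns 1-2 blocked, so nothing outside the span can look into it.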
models/modules/visualizer.py
ADDED
@@ -0,0 +1,107 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/models/visualizer.py
#!/usr/bin/env python3
# Copyright    2023                             (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch


def visualize(
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
    text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
    audio_features = batch["audio_features"].to("cpu").detach().numpy()
    audio_features_lens = (
        batch["audio_features_lens"].to("cpu").detach().numpy()
    )
    assert text_tokens.ndim == 2

    utt_ids, texts = batch["utt_id"], batch["text"]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = (
        decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
    )

    vmin, vmax = 0, 1024  # Encodec
    if decoder_outputs.dtype == np.float32:
        vmin, vmax = -6, 0  # Fbank

    num_figures = 3
    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        _ = plt.figure(figsize=(14, 8 * num_figures))

        S = text_tokens_lens[b]
        T = audio_features_lens[b]

        # encoder
        plt.subplot(num_figures, 1, 1)
        plt.title(f"Text: {text}")
        plt.imshow(
            X=np.transpose(encoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=S - 0.4, linewidth=2, color="r")
        plt.xlabel("Encoder Output")
        plt.colorbar()

        # decoder
        plt.subplot(num_figures, 1, 2)
        plt.imshow(
            X=np.transpose(decoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Output")
        plt.colorbar()

        # target
        plt.subplot(num_figures, 1, 3)
        plt.imshow(
            X=np.transpose(audio_features[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Target")
        plt.colorbar()

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()
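A hedged sketch of how visualize might be driven; the batch layout matches the keys read above, but the shapes and dummy values are assumptions, not the repo's training loop:

import torch

batch = {
    "text_tokens": torch.randint(0, 100, (2, 5)),      # [N, S]
    "text_tokens_lens": torch.tensor([5, 4]),
    "audio_features": torch.randn(2, 10, 80),          # float32 selects the fbank color range
    "audio_features_lens": torch.tensor([10, 8]),
    "utt_id": ["utt0", "utt1"],
    "text": ["hello world", "good morning"],
}
predicts = (torch.randn(2, 5, 80), torch.randn(2, 10, 80))  # (encoder_out, decoder_out)
visualize(predicts, batch, output_dir=".", limit=2)          # writes ./utt0.png, ./utt1.png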
models/voice_star.py
ADDED
@@ -0,0 +1,784 @@
import random, os, copy
from typing import Dict, Iterator, List, Tuple, Union
import logging
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy
import torch.distributed as dist

from .modules.utils import make_pad_mask, generate_partial_autoregressive_mask

from .modules.embedding import SinePositionalEmbedding, TokenEmbedding, SinePositionalEmbedding_progress
from .modules.transformer import (
    AdaptiveLayerNorm,
    LayerNorm,
    TransformerDecoderLayer,
    TransformerDecoder,
    TransformerEncoder,
    TransformerEncoderLayer,
)

def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, min_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if min_p < 1.0:
        probs = F.softmax(logits, dim=-1)
        indices_to_remove = probs < min_p
        if not torch.any(indices_to_remove.sum(-1) == logits.size(-1)):
            logits[indices_to_remove] = filter_value
            top_k = 0
            top_p = 1.0
        # else will use other types of sampling, or no filtering

    # If top_k is a single integer
    if isinstance(top_k, int) and top_k > 0:
        # Safety check to ensure we don't ask for more than available
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))

        # Remove all tokens with a probability less than the last token of the top-k
        threshold = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        indices_to_remove = logits < threshold
        logits[indices_to_remove] = filter_value

    # If top_k is a list, assume it has the same length as M
    elif isinstance(top_k, list):
        # Ensure the length matches the first dimension
        assert len(top_k) == logits.size(0), \
            f"top_k list length ({len(top_k)}) must match logits.size(0) ({logits.size(0)})"

        for i in range(logits.size(0)):
            k_i = top_k[i]
            if k_i > 0:
                # Safety check
                k_i = min(max(k_i, min_tokens_to_keep), logits.size(-1))
                row_threshold = torch.topk(logits[i], k_i, dim=-1)[0][-1]
                indices_to_remove_i = logits[i] < row_threshold
                logits[i, indices_to_remove_i] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0
        # map the removal mask back to the original (unsorted) index order and apply
        # it; this step comes from the referenced gist and was missing in the upload,
        # without it the top_p branch computed a mask but never used it
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits


def topk_sampling(logits, top_k=10, top_p=1.0, min_p=1.0, temperature=1.0):
    # temperature: (`optional`) float
    #     The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
    # top_k: (`optional`) int
    #     The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
    # top_p: (`optional`) float
    #     The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
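# Editor's sketch (not part of the original file): during inference the model
# calls topk_sampling with one row of logits per codebook, i.e. shape [K, card],
# and gets back one sampled token per codebook:
#   logits = torch.randn(4, 2048)                        # 4 codebooks, card 2048
#   token = topk_sampling(logits, top_k=40, top_p=0.9)   # -> shape [4, 1]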


class VoiceStar(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        assert self.args.enc_dec ^ self.args.dec, f"self.args.enc_dec: {self.args.enc_dec}, self.args.dec: {self.args.dec}"
        if not getattr(self.args, "special_first", False):
            self.args.special_first = 0
        if not getattr(self.args, "n_special", False):
            self.args.n_special = 3
        self.args.eos = getattr(self.args, "eos", -1)
        self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1]
        if self.args.eos > 0:
            assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
            self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
        if type(self.args.audio_vocab_size) == str:
            self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
        if type(self.args.audio_vocab_size) == list: # otherwise they are all lists
            assert self.args.special_first


        self.n_text_tokens = self.args.text_vocab_size + 1
        assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"

        if self.args.special_first and type(self.args.audio_vocab_size) == list:
            self.n_audio_tokens = [tok + self.args.n_special for tok in self.args.audio_vocab_size] # special tokens: empty token, EOG token, audio pad token
            assert self.args.empty_token == 0, self.args.empty_token
            assert self.args.eog == 1, self.args.eog
            assert self.args.audio_pad_token == 2, self.args.audio_pad_token
        else:
            self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
            assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
            assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
            assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token

        self.text_embedding = TokenEmbedding(
            dim_model=self.args.d_model,
            vocab_size=self.n_text_tokens,
            dropout=self.args.text_embedding_dropout
        )

        self.audio_embedding = nn.ModuleList(
            [
                TokenEmbedding(
                    dim_model=self.args.audio_embedding_dim,
                    vocab_size=self.n_audio_tokens[k],
                    dropout=self.args.audio_embedding_dropout
                ) for k in range(self.args.n_codebooks)
            ]
        )

        rope_base = getattr(self.args, "rope_base", None)
        use_sinusoidal = getattr(self.args, "use_sinusoidal", False)
        use_sinusoidal_progress = getattr(self.args, "use_sinusoidal_progress", False)
        logging.info(f"rope_base: {rope_base}, use_sinusoidal: {use_sinusoidal}")
        if use_sinusoidal:
            self.text_positional_embedding = SinePositionalEmbedding(
                self.args.d_model,
                dropout=self.args.text_positional_embedding_dropout,
                scale=False,
                alpha=True, # learnable scaler, scale the volume of positional embedding
            )
            self.audio_positional_embedding = SinePositionalEmbedding(
                self.args.d_model,
                dropout=self.args.audio_positional_embedding_dropout,
                scale=False,
                alpha=True, # learnable scaler, scale the volume of positional embedding
            )
        elif use_sinusoidal_progress:
            self.text_positional_embedding = SinePositionalEmbedding_progress(
                self.args.d_model,
                dropout=self.args.text_positional_embedding_dropout,
                scale=False,
                alpha=True, # learnable scaler, scale the volume of positional embedding
                args=self.args
            )
            self.audio_positional_embedding = SinePositionalEmbedding_progress(
                self.args.d_model,
                dropout=self.args.audio_positional_embedding_dropout,
                scale=False,
                alpha=True, # learnable scaler, scale the volume of positional embedding
                args=self.args
            )

        else:
            class NoOp:
                def __init__(self):
                    pass
                def __call__(self, *args, **kwargs):
                    return args[0]
            self.text_positional_embedding = NoOp()
            self.audio_positional_embedding = NoOp()

        if self.args.enc_dec:
            enc_layer = TransformerEncoderLayer(
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                dim_feedforward=self.args.d_model*4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            ) # use the pre-norm arch

            self.encoder = TransformerEncoder(
                encoder_layer=enc_layer,
                num_layers=self.args.num_encoder_layers,
                norm=LayerNorm(self.args.d_model),
                rope_base=self.args.rope_base,
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                args=self.args
            ) # use the pre-norm arch

            dec_layer = TransformerDecoderLayer(
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                dim_feedforward=self.args.d_model*4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            )

            self.decoder = TransformerDecoder(
                decoder_layer=dec_layer,
                num_layers=self.args.num_decoder_layers,
                norm=LayerNorm(self.args.d_model),
                rope_base=self.args.rope_base,
                d_model=self.args.d_model,
                nhead=self.args.nhead,
                args=self.args
            ) # NOTE: this one I use torch.nn native implementation, as it's not implemented in .modules

        else:
            dec_layer = TransformerEncoderLayer(
                self.args.d_model,
                self.args.nhead,
                dim_feedforward=self.args.d_model * 4,
                dropout=self.args.trm_dropout,
                batch_first=True,
                norm_first=True,
                layer_norm_cls=LayerNorm
            )
            self.decoder = TransformerEncoder(
                dec_layer,
                num_layers=self.args.num_decoder_layers,
                norm=LayerNorm(self.args.d_model),
            )

        if type(self.args.audio_vocab_size) == int:
            self.predict_layer = nn.ModuleList(
                [
                    nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
                ]
            )
        else:
            self.predict_layer = nn.ModuleList(
                [
                    nn.Sequential(nn.Linear(self.args.d_model, self.args.d_model//2), nn.GELU(), nn.Linear(self.args.d_model//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
                ]
            )

        self.accuracy_metrics = nn.ModuleList(
            [MulticlassAccuracy(
                self.n_audio_tokens[k],
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=None,
            ) for k in range(self.args.n_codebooks)]
        )

        if self.args.eog_weight != 1:
            raise NotImplementedError("now have different vocab_size for different codebooks, therefore currently don't support eog_weight")
            self.class_weight = nn.Parameter(torch.ones(self.n_audio_tokens), requires_grad=False)
            self.class_weight.data[self.args.eog] = self.args.eog_weight

    def dec_forward(
            self,
            x_input,
            x_lens,
            x_attention_mask,
            x_padding_mask,
            y_input,
            new_y_lens,
            y_attention_mask,
            y_padding_mask,
            need_weights=False,
            past=None,
            last_3_tokens=False
        ):
        x_attn_mask = F.pad(
            x_attention_mask,
            (0, new_y_lens.max()),
            value=True,
        ) # x attn to all x, doesn't attn to any y, this follows figure 3 of the valle paper
        y_attn_mask = F.pad(
            y_attention_mask,
            (x_lens.max(), 0),  # y is padded at the front
            value=False,
        ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge key padding and attention masks
        bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
        xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
        _xy_padding_mask = (
            xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.args.nhead, -1, -1)
            .reshape(bsz * self.args.nhead, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        new_attn_mask = torch.zeros_like(xy_attn_mask)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        xy_input = torch.cat([x_input, y_input], dim=1)
        if need_weights:
            raise NotImplementedError("not implemented yet")
            out, layer_attn_weights = self.decoder((xy_input, None), mask=xy_attn_mask, need_weights=True)
            return layer_attn_weights

        if past == None: # do not use kvcache
            out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
            return out[:, x_lens.max():], None
        else: # use kvcache
            if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
                if last_3_tokens:
                    xy_input = xy_input[:, -3:]
                    xy_attn_mask = xy_attn_mask[:, -3:]
                else:
                    xy_input = xy_input[:, -1:]
                    xy_attn_mask = xy_attn_mask[:, -1:]

            out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
            if isinstance(out, tuple): # get rid of stage_embedding
                out = out[0]

            if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
                return out[:, x_lens.max():], present
            else: # used kvcache
                return out, present

    def enc_dec_forward(
            self,
            xa,
            x_attention_mask,
            x_padding_mask,
            y_input,
            new_y_lens,
            y_attention_mask,
            y_padding_mask,
            tgt_y_lens=None,
            need_weights=False,
            past=None,
            last_3_tokens=False
        ):
        assert not need_weights
        if past != None and past.ndim > 3:
            y_input = y_input[:, -1:]
            y_attention_mask = y_attention_mask[-1:]
        yhat, present = self.decoder(tgt=y_input, memory=xa, tgt_mask=y_attention_mask, tgt_key_padding_mask=y_padding_mask, memory_key_padding_mask=x_padding_mask, query_lens=tgt_y_lens, past=past)
        return yhat, present

    def forward(self, batch, calc_loss=False):
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, K, T).
            where K is the number of codebooks
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
            before padding.
        """
        x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
        if len(x) == 0:
            return None
        x = x[:, :x_lens.max()] # this deals with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
        y = y[..., :y_lens.max()]
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
        assert y_lens.ndim == 1, y_lens.shape
        x_padding_mask = make_pad_mask(x_lens).to(x.device)
        x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
        x_input = self.text_embedding(x)
        x_input = self.text_positional_embedding(x_input, x_lens)
        y_with_eos = [torch.cat([item[:, :y_lens[i]], self.eos], dim=-1) for i, item in enumerate(y)]
        targets = y_with_eos
        # apply delayed stacking on y
        shifted_y = []
        patterns = []
        new_y_lens = []
        if getattr(self, "empty_tokens", None) == None:
            self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K]
        for i in range(len(y)):
            tmp = torch.cat([y_with_eos[i], self.empty_tokens], dim=-1) # [K, T+n_codebooks]
            for ii in range(self.args.n_codebooks):
                tmp[ii] = torch.roll(tmp[ii], shifts=ii+1, dims=0)
            shifted_y.append(tmp.transpose(1, 0)) # [K, T+n_codebooks] -> [T+n_codebooks, K]
            new_y_lens.append(y_with_eos[i].shape[1] + self.empty_tokens.shape[1])

        new_y_lens = torch.LongTensor(new_y_lens).to(y.device)

        cated_y = torch.nn.utils.rnn.pad_sequence(shifted_y, batch_first=False, padding_value=self.args.audio_pad_token)
        assert cated_y.shape == torch.Size([max(new_y_lens), len(y), self.args.n_codebooks]), cated_y.shape
        cated_y = cated_y.permute(2, 0, 1) # [T,B,K]->[K,T,B]
        stacked_embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
        assert stacked_embedded_y.shape[0] == self.args.n_codebooks and stacked_embedded_y.shape[2] == len(y) and stacked_embedded_y.shape[-1] == self.args.d_model, stacked_embedded_y.shape
        embedded_y = stacked_embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
        embedded_y = embedded_y.transpose(1, 0) # [T,B,D]->[B,T,D]
        assert embedded_y.shape[1:] == torch.Size([max(new_y_lens), self.args.d_model]), embedded_y.shape
        y_input = self.audio_positional_embedding(embedded_y, new_y_lens)
        y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
        y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
        if self.args.dec:
            y_out = self.dec_forward(
                x_input,
                x_lens,
                x_attention_mask,
                x_padding_mask,
                y_input,
                new_y_lens,
                y_attention_mask,
                y_padding_mask
            )
        else:
            xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask)
            y_out = self.enc_dec_forward(
                xa,
                x_attention_mask,
                x_padding_mask,
                y_input,
                new_y_lens,
                y_attention_mask,
                y_padding_mask
            )
        y_out = y_out[0] # no kv-caching during training
        assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
        logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
        assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
        logits_use = [logit[:, :new_y_lens[i]] for i, logit in enumerate(logits)] # each of shape [K, T, card]
        logits_final = []
        for i, logit in enumerate(logits_use):
            logit_copy = logit.clone()
            for ii in range(self.args.n_codebooks):
                logit_copy[ii] = torch.roll(logit_copy[ii], shifts=-ii, dims=0)
            logit = logit_copy[:, :-self.args.n_codebooks] # [K, T, card] -> [K, T-n_codebooks, card]
            logits_final.append(logit)
        if self.args.no_loss_on_prefix:
            assert "y_sep_token_position" in batch, f"y_sep_token_position should be in batch, but it's not"
            logit_temp = []
            target_temp = []
            for jj, (logit, target) in enumerate(zip(logits_final, targets)):
                # TODO already taken into consideration in depth transformer
                logit_temp.append(logit[:, batch['y_sep_token_position'][jj]:])
                target_temp.append(target[:, batch['y_sep_token_position'][jj]:])
            logits_final = logit_temp
            targets = target_temp
        logits = torch.cat(logits_final, dim=1) # [K, T1+T2+T3+..., card]
        targets = torch.cat(targets, dim=1) # [K, T1+T2+T3+...]

        assert targets.shape[:2] == logits.shape[:2], f"{targets.shape}, {logits.shape}"
        loss = []
        ntokens = []
        top10acc = []
        for k, (logit, target) in enumerate(zip(logits, targets)): # even though the loss and top10acc are calculated in a loop (loop through n_codebooks), validation is still taking a lot of mem, need to optimize this a little more
            loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight != 1 else None, ignore_index=self.args.y_sep_token if self.args.y_sep_token != None else -100)) # ignore audio sep token as it's unpredictable (like the random early stop bug happened in 2023)
            # NOTE have to ignore the sep token in the loss calculation
            top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
            ntokens.append(len(logit))

        all_ntokens = sum(ntokens)
        if self.args.codebook_weight != None:
            codebook_weight = eval(self.args.codebook_weight) if isinstance(self.args.codebook_weight, str) else self.args.codebook_weight
        else:
            codebook_weight = [1.] * self.args.n_codebooks
        perplexity_by_codebook = [torch.exp(l) for l in loss]
        loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])

        top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
        top10acc = sum(top10acc_by_codebook)

        ntokens = torch.tensor(all_ntokens).to(logits.device)

        ret = {
            "loss": loss,
            "perplexity_by_codebook": perplexity_by_codebook,
            "top10acc": top10acc,
            "top10acc_by_codebook": top10acc_by_codebook,
            "effective_ntoken": ntokens,
        }

        return ret

    def inference_tts(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        tgt_y_lens: torch.Tensor,
        top_k: Union[int, list[int]] = -100,
        top_p: float = 1.0,
        min_p: float = 1.0,
        temperature: float = 1.0,
        stop_repetition: int = 3,
        kvcache: int = 1,
        silence_tokens: list[int] = [],
        multi_trial: list[int] = [],
        *kargs
    ) -> torch.Tensor:
        """
        different from inference_tts, this implementation uses kvcache, which should have significant speed up
        Args:
          x:
            A 2-D tensor of shape (1, L).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, K).
          tgt_y_lens:
            *new arg* this specifies the target length of y
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
          top_p: (`optional`) float
            For nucleus sampling
          min_p: (`optional`) float
            For min_p filtered sampling
          temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
          multi_trial: (`optional`) list[int]
            If not empty, it will be [n_trials, beam_size, trial_interval]
            from the start and at the beginning of every trial_interval, we duplicate the current sample by beam_size,
            at the end of every trial_interval, we choose the sample with the highest log likelihood to keep and throw away the rest
        """
        eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        if self.args.special_first:
            y = y + int(self.args.n_special)
        y = y.transpose(2, 1) # [1,T,K] -> [1,K,T]
        assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding

        # make x attention mask and x_input
        x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
        # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
        x_input = self.text_embedding(x)
        x_input = self.text_positional_embedding(x_input, x_lens)

        y_len = y.shape[2]
        y_lens = torch.LongTensor([y_len]).to(y.device)

        # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
        rearranged_y = [[y[0]]]
        assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape

        # shift y to create the delayed pattern
        if getattr(self, "empty_tokens", None) == None:
            self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K]
        temp = rearranged_y[0][0]
        assert temp.ndim == 2 and temp.shape[0] == self.args.n_codebooks, temp.shape
        temp = torch.cat([temp, self.empty_tokens], dim=-1) # [K, T+n_codebooks]
        for ii in range(self.args.n_codebooks):
            temp[ii] = torch.roll(temp[ii], shifts=ii+1, dims=0)
        shifted_y = [[temp]]

        # below is different from forward or inference
        # where we cut this shifted part
        shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
        assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
        # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
        # next section is concatenate tensors of each sample to one tensor, which we also don't need
        cated_y = shifted_y[0][0].unsqueeze(-1) # [K,S]->[K,S,B]
        new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
        assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
        assert not (cated_y == self.args.audio_pad_token).any(), cated_y

        # replace tokens in y with the embeddings, and sum codebooks up
        embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
        assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
        assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
        embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
        embedded_y = embedded_y.transpose(1, 0) # [S,B,D]->[B,S,D]

        # positional embedding
        y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens)

        # make attention mask and padding mask
        y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)

        x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
        y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

        # entering the generation stage
        # starting from line 708
        codebook_eog = [False] * self.args.n_codebooks
        generated = [] # doesn't contain any empty token, contain eog
        cur_generated = []
        # say 0 is empty, 4 is eog
        # tensor([[ 1,  2,  3,  4,  0,  0],
        #         [ 0,  1,  2,  3,  4,  0],
        #         [ 0,  0,  1,  2,  3,  4]])
        num_gen = []
        cur_num_gen = 0
        ##################### silence repetition handling #####################
        # silence_tokens = [1388,1898,131] # [1388, 2045, 2041, 1996]
        consec_silence_count = 0
        prev_token = None
        ##################### silence repetition handling #####################

        def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
            if n_eog == 0:
                logits_adjust = logits
                for jj in range(1, self.args.n_codebooks):
                    logits_adjust[jj][eog_inference] = -10000
                    logits_adjust[jj][self.args.empty_token] = -10000
                if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
                    logits_adjust[0][eog_inference] = -10000
                ##################### silence repetition handling #####################
                if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
                    if logits_adjust[0, prev_token] < 0:
                        logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition - 1))
                    else:
                        logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition - 1))
                ##################### silence repetition handling #####################
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
                ) # [K, 1]
                assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
                if cur_num_gen < self.args.n_codebooks - 1:
                    for jj in range(1, self.args.n_codebooks - cur_num_gen):
                        samples[-jj, 0] = self.args.empty_token

                if (
                    samples[0, 0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr // 4)
                ) or self.args.rope_base is not None and not self.args.decoder_regular_rope and self.args.progress_no_multiple and cur_num_gen > (tgt_y_lens[0] + self.args.encodec_sr * getattr(self.args, "extra_cutoff", 5)):
                    # the last condition in the first bracket means y is already too long, shouldn't happen, but put it here
                    # the second bracket means we are using progress-monitoring RoPE, but the model is generating an excessively long sequence (5 seconds more than specified), in which case we terminate the generation
                    samples[0, 0] = eog_inference
                    codebook_eog[0] = True
                ##################### silence repetition handling #####################
                if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
                    consec_silence_count += 1
                else:
                    consec_silence_count = 0
                prev_token = samples[0, 0]
                ##################### silence repetition handling #####################
                return samples, codebook_eog, prev_token, consec_silence_count
            else:
                assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
                logits_adjust = logits
                for jj in range(n_eog + 1, self.args.n_codebooks):
                    logits_adjust[jj][eog_inference] = -10000
                    logits_adjust[jj][self.args.empty_token] = -10000
                samples = topk_sampling(
                    logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
                ) # [K, 1]
                for jj in range(n_eog):
                    samples[jj, 0] = self.args.empty_token
                samples[n_eog, 0] = eog_inference
                codebook_eog[n_eog] = True
                return samples, codebook_eog, prev_token, consec_silence_count

        # prepare the cache placeholder
        # n_layers, 2, bsz, num_heads, src_len, head_dim, 2 means [key, value]
        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
        if self.args.enc_dec:
            xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask)
        while True:
            if self.args.dec:
                y_out, present = self.dec_forward(
                    x_input,
                    x_lens,
                    x_attention_mask,
                    x_padding_mask,
                    y_input,
                    new_y_lens,
                    y_attention_mask,
                    y_padding_mask,
                    past=past
                )
            else:
                y_out, present = self.enc_dec_forward(
                    xa,
                    x_attention_mask,
                    x_padding_mask,
                    y_input,
                    new_y_lens,
                    y_attention_mask,
                    y_padding_mask,
                    tgt_y_lens=tgt_y_lens,
                    past=past
                )
            if past != None:
                past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)


            y_out = y_out[:, -1:] # only take the last token

            logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
            logits = logits.squeeze(0).squeeze(1) # [K card]
            assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"

            n_eog = sum(codebook_eog)
            assert n_eog < self.args.n_codebooks
            if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there are no masked spans
                for jj in range(self.args.n_codebooks):
                    logits[jj][self.args.eog] = -10000.

            samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
            # samples.shape is [K,1]
            # get samples_emb
            samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
            samples_emb = samples_emb.sum(dim=0, keepdim=True) # [1,1,D]

            cur_num_gen += 1
            cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]

            if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
                codebook_eog = [False] * self.args.n_codebooks
                num_gen.append(cur_num_gen)
                cur_num_gen = 0
                generated.append(cur_generated)
                cur_generated = []
                break
            else:
                assert samples_emb.shape == torch.Size((1, 1, self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"

            embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
            new_y_lens = torch.LongTensor([embedded_y.shape[1]]).to(y.device)
            y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) # [B T D]
            # make attention mask and padding mask
            y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
            y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)

        assert len(generated) == 1, f"len(generated): {len(generated)}"

        # revert the pattern
        flatten_gen = []
        for l, orig_span in enumerate(generated):
            span = torch.stack(orig_span, dim=0) # [T, K]
            span = span.transpose(1, 0) # [K, T]
            assert span.shape[0] == self.args.n_codebooks, span.shape
            unshifted_span = []
            for j, s in enumerate(span):
                start_from = j
                end_at = -(self.args.n_codebooks - start_from)
                unshifted_span.append(s[start_from:end_at])
            unshifted_span = torch.stack(unshifted_span, dim=0)

            assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"

            flatten_gen.append(unshifted_span)
        assert len(flatten_gen) == 1, len(flatten_gen)

        # combine
        res = [y[0], flatten_gen[0]]
        res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
        expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
        assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"

        if self.args.special_first:
            res = res - int(self.args.n_special)
            # apply the offset per span; flatten_gen is a Python list of tensors,
            # so subtracting the int from the list itself (as originally written) would raise a TypeError
            flatten_gen = [item - int(self.args.n_special) for item in flatten_gen]
        return res, flatten_gen[0].unsqueeze(0)
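The shift/unshift logic in forward and inference_tts implements the delayed codebook pattern sketched in the inline comment above ("say 0 is empty, 4 is eog"): codebook k is rolled right by k+1 steps, so codebook k for a given frame is only predicted after codebooks 0..k-1 for that frame. A minimal round-trip sketch of the roll and its inverse, using toy values rather than the model's special tokens:

import torch

K, T, EMPTY = 3, 4, 0                        # toy sizes; 0 stands in for empty_token
y = torch.arange(1, K * T + 1).view(K, T)    # [K, T] fake codec tokens
tmp = torch.cat([y, torch.full((K, K), EMPTY)], dim=-1)   # pad K extra frames
for k in range(K):
    tmp[k] = torch.roll(tmp[k], shifts=k + 1, dims=0)     # delay codebook k by k+1

# invert, mirroring the unshifted_span loop in inference_tts
recovered = torch.stack([tmp[k, k + 1:k + 1 + T] for k in range(K)], dim=0)
assert torch.equal(recovered, y)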
pretrained/.gitkeep
ADDED
File without changes
steps/__init__.py
ADDED
File without changes
steps/optim.py
ADDED
@@ -0,0 +1,1123 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer


class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: the parameters or param_groups to optimize (like other Optimizer subclasses)
      defaults: default hyperparameter values for the param groups
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
        tuples (p, state, p_names), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for the first of the real
        parameters that share that shape and dtype);
        `p_names` holds the names of the stacked parameters.

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations after the `yield`.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state, p_names in batches:
                 ...
        </code>

        Args:
          group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str, *shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str, *shape) to list of str

        assert len(param_group) == len(group_params_names), (
            f"len(param_group): {len(param_group)}, "
            f"len(group_params_names): {len(group_params_names)}"
        )
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()

        # turn batches into a list, in deterministic order.
        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
        # one for each batch in `batches`.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack(
                [
                    torch.zeros_like(p) if p.grad is None else p.grad
                    for p in batch
                ]
            )
            p_stacked.grad = grad
            stacked_params_dict[(str(p.dtype), *p.shape)] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        # write the stacked values back to the individual parameters.
        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])


class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the parameter,
    in log space, subject to upper and lower limits (as if we had factored each parameter as
    param = underlying_param * log_scale.exp())


    Args:
        params: The parameters or param_groups to optimize (like other Optimizer subclasses)
        lr: The learning rate.  We will typically use a learning rate schedule that starts
            at 0.03 and decreases over time, i.e. much higher than other common
            optimizers.
        clipping_scale: (e.g. 2.0)
            A scale for gradient-clipping: if specified, the normalized gradients
            over the whole model will be clipped to have 2-norm equal to
            `clipping_scale` times the median 2-norm over the most recent period
            of `clipping_update_period` minibatches.  By "normalized gradients",
            we mean after multiplying by the rms parameter value for this tensor
            [for non-scalars]; this is appropriate because our update is scaled
            by this quantity.
        betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
            Must satisfy 0 < beta1 <= beta2 < 1.
        scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
            scale of each parameter tensor and the scalar parameters of the model.
            If each parameter were decomposed
            as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
            would be the scaling factor on the learning rate of p_scale.
        eps: A general-purpose epsilon to prevent division by zero
        param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each non-scalar
            parameter tensor to be >= this value)
        param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each non-scalar
            parameter tensor to be <= this value)
        scalar_max: Maximum absolute value for scalar parameters (applicable if your
            model has any parameters with numel() == 1).
        size_update_period: The periodicity, in steps, with which we update the size (scale)
            of the parameter tensor.  This is provided to save a little time
            in the update.
        clipping_update_period: if clipping_scale is specified, this is the period,
            in minibatches, with which the clipping threshold is re-estimated.
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True,
    ):

        assert parameters_names is not None, (
            "Please prepare parameters_names, "
            "which is a List[List[str]]. Each List[str] is for a group "
            "and each str is for a parameter"
        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
            betas=betas,
            scalar_lr_scale=scalar_lr_scale,
            eps=eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period,
        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names
        self.show_dominant_parameters = show_dominant_parameters

    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, group_params_names in zip(
            self.param_groups, self.parameters_names
        ):

            with self.batched_params(
                group["params"], group_params_names
            ) as batches:

                # batches is a list of tuples (stacked_param, state, param_names).
                # stacked_param is like a regular parameter, and will have a .grad,
                # but the 1st dim corresponds to a stacking dim, it is not a real dim.

                if (
                    len(batches[0][1]) == 0
                ):  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # Perform optimization step.
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )
                    # State initialization
                    if len(state) == 0:
                        self._init_state(group, p, state)

                    self._step_one_batch(group, p, state, clipping_scale)

        return loss

    def _init_state(self, group: dict, p: Tensor, state: dict):
        """
        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
        parameters of a given shape.


        Args:
           group: Dict to look up configuration values.
               p: The parameter that we are initializing the state for
           state: Dict from string to whatever state we are initializing
        """
        size_update_period = group["size_update_period"]

        state["step"] = 0

        kwargs = {"device": p.device, "dtype": p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
        # compute "exp_avg" like in Adam, we store and decay a
        # parameter-change "delta", which combines all forms of
        # update.  this is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        batch_size = p.shape[0]
        numel = p.numel() // batch_size  # number of elements per real parameter

        if numel > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
            param_rms = (
                (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
            )
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            state["scale_grads"] = torch.zeros(
                size_update_period, *param_rms.shape, **kwargs
            )

        # exp_avg_sq is the weighted sum of scaled gradients, as in Adam.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

    def _get_clipping_scale(
        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
    ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.

        Args:
           group: the parameter group, an item in self.param_groups
           tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str], where each str is the name of a parameter
                in the batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # no clipping.  return early on step == 0 because the other
            # parameters' state won't have been initialized yet.
            return 1.0
        clipping_update_period = group["clipping_update_period"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients"
                )
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (
                    grad ** 2
                ).sum()  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device
            )
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
            # Print some stats.
            # We don't reach here if step == 0 because we would have returned
            # above.
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            quartiles = []
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
                    (clipping_update_period // 4) * n,
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / clipping_update_period
                if "num_clipped" in first_state
                else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
            try:
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warning(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                if self.show_dominant_parameters:
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans

    def _show_gradient_dominating_parameter(
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information about the parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str], where each str is the name of a parameter
                in the batched set of parameters "param".
            tot_sumsq: sumsq of all parameters.  Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad ** 2
                # Dummy values used by the following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
                    dim=list(range(1, batch_grad.ndim))
                )

            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )

    def _step_one_batch(
        self, group: dict, p: Tensor, state: dict, clipping_scale: float
    ):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
            group: dict to look up configuration values
                p: parameter to update (actually multiple parameters stacked together
                   as a batch)
            state: state-dict for p, to look up the optimizer state
            clipping_scale: gradient-clipping factor computed by _get_clipping_scale()
        """
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        beta1 = group["betas"][0]

        grad = p.grad
        if clipping_scale != 1.0:
            grad = grad * clipping_scale
        step = state["step"]
        delta = state["delta"]

        delta.mul_(beta1)
        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True
            )
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(
                    (p ** 2)
                    .mean(dim=list(range(1, p.ndim)), keepdim=True)
                    .sqrt()
                )
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
            self._step_scalar(group, p, state)
        else:
            self._step(group, p, state)

        state["step"] = step + 1

    def _size_update(
        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
            group: dict to look up configuration values
            scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1, ...)
                containing grads w.r.t. the scales.
                p: The parameter to update
            state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
        batch_size = p.shape[0]

        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period

        scale_exp_avg_sq = state[
            "scale_exp_avg_sq"
        ]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(
                dim=0
            ),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr ** size_step
        # we don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (
            -size_lr
            * (bias_correction2 ** 0.5)
            * scale_grads.sum(dim=0)
            / denom
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
        # when it gets too large, stop it from getting any larger.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))

    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.

        Args:
            group: A dict which will be used to look up configuration values
                p: The parameter to be updated
            state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        param_min_rms = group["param_min_rms"]
        step = state["step"]

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        this_step = state["step"] - (
            state["zero_step"] if "zero_step" in state else 0
        )
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps
        grad = grad / denom

        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)

    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)

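# Usage note (editor's illustration, not part of the original file): unlike the
# stock PyTorch optimizers, ScaledAdam must be given `parameters_names`, one
# List[str] per parameter group, e.g.:
#
#     model = torch.nn.Linear(4, 4)
#     names = [[n for n, _ in model.named_parameters()]]
#     optim = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0,
#                        parameters_names=names)
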
class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
        # you should call this at the start of the epoch; if you don't provide the 'epoch'
        # arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.info(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )


class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                      (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.


    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need a smaller number if the dataset is huge
              and you will do few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        factor = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25 * (
            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
            ** -0.25
        )
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]


def _test_eden():
    m = torch.nn.Linear(100, 100)
    # parameters_names is required by ScaledAdam's constructor.
    optim = ScaledAdam(
        m.parameters(),
        lr=0.03,
        parameters_names=[[n for n, _ in m.named_parameters()]],
    )

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")


# This is included mostly as a baseline for ScaledAdam.
class Eve(Optimizer):
    """
    Implements Eve algorithm.  This is a modified version of AdamW with a special
    way of setting the weight-decay / shrinkage-factor, which is designed to make the
    rms of the parameters approach a particular target_rms (default: 0.1).  This is
    for use with networks with 'scaled' versions of modules (see scaling.py), which
    will be close to invariant to the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-3);
            this value means that the weight would decay significantly after
            about 3k minibatches.  It is not multiplied by the learning rate, but
            is conditional on the RMS-value of the parameter being > target_rms.
        target_rms (float, optional): target root-mean-square value of
            parameters; if they fall below this we will stop applying weight decay.


    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Eve, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1
                target_rms = group["target_rms"]
                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

                # if random.random() < 0.0005:
                #     step = (exp_avg / denom) * step_size
                #     logging.info(
                #         f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
                #     )

        return loss


def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts,
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(
                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
            )
    return ans


def _test_scaled_adam(hidden_dim: int):
    import timeit

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    for iter in [1, 0]:
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype)
                * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            # parameters_names is required by ScaledAdam's constructor.
            optim = ScaledAdam(
                m.parameters(),
                lr=0.03,
                clipping_scale=2.0,
                parameters_names=[[n for n, _ in m.named_parameters()]],
            )
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2,3]:
            #     optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     )  # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}")  # scales={scale1,scale1b,scale2,scale2b}
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

            # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:
        hidden_dim = 200

    _test_scaled_adam(hidden_dim)
    _test_eden()
steps/trainer.py
ADDED
@@ -0,0 +1,717 @@
1 |
+
import time, sys, subprocess, json, re
|
2 |
+
from pathlib import Path
|
3 |
+
import os, random
|
4 |
+
import torch
|
5 |
+
import math, pickle
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torch.optim import AdamW
|
8 |
+
from torch.optim.lr_scheduler import LambdaLR
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.distributed as dist
|
11 |
+
from torch.utils.data.sampler import Sampler
|
12 |
+
import copy
|
13 |
+
from torch.utils.tensorboard import SummaryWriter
|
14 |
+
import numpy as np
|
15 |
+
from torch.utils.data.distributed import DistributedSampler
|
16 |
+
import logging
|
17 |
+
# from data import librilight, gigaspeech, gigaspeech_waveform
|
18 |
+
from data import combined_dataset
|
19 |
+
from models import voice_star
|
20 |
+
|
21 |
+
from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, StatefulSampler, AverageMeter, print_model_info
|
22 |
+
from .optim import ScaledAdam, Eden
|
23 |
+
import run_gen
|
24 |
+
import wandb, socket
|
25 |
+
|
26 |
+
class Trainer:
|
27 |
+
|
28 |
+
def __init__(self, args, world_size, rank, local_rank):
|
29 |
+
self.start_time = time.time()
|
30 |
+
self.args = args
|
31 |
+
if self.args.val_max_num_tokens == None:
|
32 |
+
self.args.val_max_num_tokens = self.args.max_num_tokens
|
33 |
+
self.world_size, self.rank, self.local_rank = world_size, rank, local_rank
|
34 |
+
self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
|
35 |
+
if self.rank == 0:
|
36 |
+
self.writer = SummaryWriter(args.exp_dir)
|
37 |
+
self.wandb = wandb.init(project="voice_editor", name=args.exp_dir.split("/")[-1], config=args, dir=args.exp_dir, entity=self.args.wandb_entity)
|
38 |
+
self.seed_everything(seed=self.args.seed)
|
39 |
+
self.meters = self._setup_meters()
|
40 |
+
|
41 |
+
self.progress, self.total_progress = self._setup_progress()
|
42 |
+
|
43 |
+
self.model, self.trainables, self.optim_states, self.scheduler_states, self.phn2num = self._setup_models()
|
44 |
+
|
45 |
+
self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader() # both are use DistributedSampler, train sampler is stateful
|
46 |
+
if self.args.num_steps != None:
|
47 |
+
self.total_step = self.args.num_steps
|
48 |
+
self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
|
49 |
+
else:
|
50 |
+
self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs
|
51 |
+
|
52 |
+
self.optimizer, self.scheduler = self._setup_optimizer()
|
53 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
54 |
+
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], find_unused_parameters=False)
|
55 |
+
self.early_stop_accu_steps = 0
|
56 |
+
if self.rank == 0:
|
57 |
+
if self.args.dynamic_batching:
|
58 |
+
logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}")
|
59 |
+
else:
|
60 |
+
logging.info(f"batch size (per gpu): {self.args.batch_size}")
|
61 |
+
|
62 |
+
self.args.inference_every_n_steps = getattr(self.args, "inference_every_n_steps", self.args.val_every_n_steps*5)
|
63 |
+
assert self.args.inference_every_n_steps > self.args.val_every_n_steps and self.args.inference_every_n_steps % self.args.val_every_n_steps == 0, "inference_every_n_steps should be divisible by val_every_n_steps, otherwise the code will not get a chance to run inference"
|
64 |
+
|
65 |
+
def train(self):
|
66 |
+
flag = True
|
67 |
+
skip_flag = False
|
68 |
+
data_start_time = time.time()
|
69 |
+
if self.progress['step'] >= self.total_step:
|
70 |
+
if self.rank == 0:
|
71 |
+
self.writer.close()
|
72 |
+
self.wandb.finish()
|
73 |
+
return
|
74 |
+
while flag:
|
75 |
+
self.train_sampler.set_epoch(self.progress['epoch'])
|
76 |
+
for i, batch in enumerate(self.train_loader):
|
77 |
+
if len(batch['y_lens']) < self.args.gradient_accumulation_steps:
|
78 |
+
continue
|
79 |
+
data_end_time = time.time()
|
80 |
+
self.model.train()
|
81 |
+
if self.progress['step'] >= getattr(self.args, "uniform_weight_start_step", 1e50):
|
82 |
+
if self.progress['step'] == getattr(self.args, "uniform_weight_start_step", 1e50) and self.rank == 0:
|
83 |
+
logging.info("NOTE: start using uniform weight from step: {}".format(self.progress['step']))
|
84 |
+
self.args.codebook_weight = [2.5,2,1.5,0.6]
|
85 |
+
self.model.module.args.codebook_weight = [2.5,2,1.5,0.6]
|
86 |
+
|
87 |
+
if self.progress['step'] >= self.total_step:
|
88 |
+
dist.barrier()
|
89 |
+
flag = False
|
90 |
+
self.validate_and_save()
|
91 |
+
if self.rank == 0:
|
92 |
+
self.writer.close()
|
93 |
+
self.wandb.finish()
|
94 |
+
break
|
95 |
+
if isinstance(self.scheduler, Eden):
|
96 |
+
self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1)
|
97 |
+
if self.args.optimizer_name == "ScaledAdam":
|
98 |
+
cur_lr = self.scheduler.get_last_lr()[0]
|
99 |
+
else:
|
100 |
+
lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
|
101 |
+
assert lrs[0] == lrs[1]
|
102 |
+
cur_lr = lrs[0]
|
103 |
+
|
104 |
+
if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
|
105 |
+
self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
|
106 |
+
self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])
|
107 |
+
|
108 |
+
all_inds = list(range(len(batch['y'])))
|
109 |
+
sum_losses = 0
|
110 |
+
sum_top10acc = 0
|
111 |
+
sum_ntoken = 0
|
112 |
+
sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
|
113 |
+
# extra losses
|
114 |
+
sum_extra_losses = {}
|
115 |
+
# when using prompt-based training, it's likely that due to prompt, the total length gets much longer, which make effective batch size in each accumulation step much bigger and then lead to OOM.
|
116 |
+
# therefore we re-calculate graduent_accumulation_steps based on the effective batch size
|
117 |
+
|
118 |
+
if self.args.neighbor_prompt_prob > 0:
|
119 |
+
effective_batch_size = self.args.max_num_tokens // self.args.gradient_accumulation_steps
|
120 |
+
total_batch_size = sum(batch['y_lens']).item()
|
121 |
+
cur_gradient_accumulation_steps = max(self.args.gradient_accumulation_steps, total_batch_size // effective_batch_size)
|
122 |
+
gas = torch.tensor(cur_gradient_accumulation_steps, dtype=torch.int, device=self.local_rank)
|
123 |
+
dist.all_reduce(gas, op=dist.ReduceOp.MAX)
|
124 |
+
cur_gradient_accumulation_steps = gas.item()
|
125 |
+
len_batch = torch.tensor(len(batch['y']), dtype=torch.int, device=self.local_rank)
|
126 |
+
dist.all_reduce(len_batch, op=dist.ReduceOp.MIN)
|
127 |
+
len_batch = len_batch.item()
|
128 |
+
cur_gradient_accumulation_steps = min(cur_gradient_accumulation_steps, len_batch)
|
129 |
+
# for those that cur_gradient_accumulation_steps * effective_batch_size < total_batch_size, we only use the first cur_gradient_accumulation_steps * effective_batch_size samples
|
130 |
+
cur_len = 0
|
131 |
+
final_all_inds = []
|
132 |
+
pointer = 0
|
133 |
+
while cur_len < self.args.max_num_tokens and pointer < len(all_inds):
|
134 |
+
cur_len += batch['y_lens'][pointer]
|
135 |
+
final_all_inds.append(all_inds[pointer])
|
136 |
+
pointer += 1
|
137 |
+
all_inds = final_all_inds
|
138 |
+
else:
|
139 |
+
cur_gradient_accumulation_steps = self.args.gradient_accumulation_steps
|
140 |
+
|
141 |
+
|
                sum_losses_local = 0.0
                sum_top10acc_local = 0.0
                sum_entropy_loss_local = 0.0
                sum_ctc_loss_local = 0.0
                sum_ntoken_local = 0.0
                sum_top10acc_cbi_local = [0.0 for _ in range(self.args.n_codebooks)]

                global_nan_flag = 0
                for j in range(cur_gradient_accumulation_steps):
                    cur_ind = all_inds[j::cur_gradient_accumulation_steps]
                    cur_batch = {key: batch[key][cur_ind] for key in batch}

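                    # Added note: the strided slice spreads samples across micro-batches,
                    # e.g. with 6 samples and 3 accumulation steps:
                    #   all_inds[0::3] -> [0, 3], all_inds[1::3] -> [1, 4], all_inds[2::3] -> [2, 5]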
                    # Automatic casting
                    if self.args.precision == "float16":
                        precision_used = torch.float16
                    elif self.args.precision in ["bf16", "bfloat16"]:
                        precision_used = torch.bfloat16
                    else:
                        precision_used = torch.float32

                    with torch.amp.autocast('cuda', dtype=precision_used):
                        out = self.model(cur_batch, calc_loss=True)
                        if out is None:
                            continue

                    if torch.isnan(out['loss']).any():
                        local_nan_flag = torch.tensor(1, device=self.local_rank)
                    else:
                        local_nan_flag = torch.tensor(0, device=self.local_rank)

                    # All ranks check if *any* rank got a NaN
                    dist.all_reduce(local_nan_flag, op=dist.ReduceOp.SUM)
                    global_nan_flag = local_nan_flag.item()
                    if global_nan_flag > 0:
                        # Now *all* ranks break at the same j
                        logging.info(f"rank: {self.rank}. Loss at micro-batch {j} in step {self.progress['step']} was NaN on at least one rank; skipping.")
                        break

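                    # Added note: the running metrics below are detached and read out with
                    # .item(), so no autograd graph is kept alive across micro-batches;
                    # only the scaled loss passed to backward() participates in autograd.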
                    # Accumulate local values
                    record_loss = out['loss'].detach()
                    top10acc = out['top10acc'].detach()
                    effective_ntoken = out['effective_ntoken'].detach()

                    sum_losses_local += record_loss.item()
                    sum_top10acc_local += top10acc.item()
                    sum_ntoken_local += effective_ntoken.item()

                    # Optional losses
                    if 'entropy_loss' in out:
                        sum_entropy_loss_local += out['entropy_loss'].detach().item()
                    if 'ctc_loss' in out:
                        sum_ctc_loss_local += out['ctc_loss'].detach().item()

                    # Codebook accuracy
                    if 'top10acc_by_codebook' in out:
                        for cb in range(self.args.n_codebooks):
                            sum_top10acc_cbi_local[cb] += out['top10acc_by_codebook'][cb].detach().item()

                    # Backprop on this micro-batch
                    if self.args.optimizer_name == "ScaledAdam":
                        self.scaler.scale(out['loss']).backward()
                    else:
                        self.scaler.scale(out['loss'] / out['effective_ntoken']).backward()

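                # Added note: ScaledAdam (from k2/icefall) adapts update magnitudes per
                # parameter, so it is given the raw summed loss above; the AdamW path
                # divides by the effective token count first so the gradient scale is
                # per-token and does not grow with batch length.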
                if global_nan_flag > 0:
                    # If *any* rank had NaN, skip this step
                    logging.info(f"rank: {self.rank}. Loss at one micro-batch in step {self.progress['step']} was NaN on at least one rank; skipping.")
                    self.progress['step'] += 1
                    self.progress['cur_step'] += 1
                    self.optimizer.zero_grad()
                    continue

                # Otherwise, do one big reduce for the summed metrics
                metrics_tensor = torch.tensor([
                    sum_losses_local,
                    sum_top10acc_local,
                    sum_entropy_loss_local,
                    sum_ctc_loss_local,
                    sum_ntoken_local
                ], device=self.local_rank, dtype=torch.float32)

                dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)

                # Also reduce the codebook array in one shot if needed
                codebook_tensor = torch.tensor(sum_top10acc_cbi_local, device=self.local_rank, dtype=torch.float32)
                dist.all_reduce(codebook_tensor, op=dist.ReduceOp.SUM)

                # Convert them back to Python scalars
                sum_losses = metrics_tensor[0].item()
                sum_top10acc = metrics_tensor[1].item()
                sum_entropy_loss = metrics_tensor[2].item()
                sum_ctc_loss = metrics_tensor[3].item()
                sum_ntoken = metrics_tensor[4].item()

                sum_top10acc_cbi = codebook_tensor.tolist()

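                # Added note: packing the five scalars into one tensor replaces five
                # all_reduce round-trips with one, and accumulating in float32 avoids
                # rounding the sums when training in fp16/bf16.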
                if self.args.optimizer_name != "ScaledAdam":
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                if self.args.optimizer_name == "ScaledAdam":
                    self.scheduler.step_batch(self.progress['step'])
                else:
                    self.scheduler.step()

                # logging
                if self.rank == 0:
                    average_loss = sum_losses / sum_ntoken
                    average_top10acc = sum_top10acc / sum_ntoken
                    average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
                    self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size)
                    self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size)
                    for cb in range(self.args.n_codebooks):
                        self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size)
                    self.meters['data_time'].update(data_end_time - data_start_time)
                    self.meters['train_time'].update(time.time() - data_end_time)

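                    # Added note: each codebook presumably accounts for 1/n_codebooks of the
                    # counted tokens, so average_top10acc_cbi above multiplies by n_codebooks
                    # to put per-codebook accuracy on the same 0-1 scale as the overall metric.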
                    # log extra losses
                    for key in sum_extra_losses:
                        if "train_" + key not in self.meters:
                            self.meters["train_" + key] = AverageMeter()
                        self.meters["train_" + key].update(sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), batch['x'].shape[0]*self.world_size)

                    if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                        self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
                        self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
                        self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
                        self.wandb.log({"train/loss": average_loss, "train/top10acc": average_top10acc, "train/ntokens": sum_ntoken, "train/data_time": data_end_time - data_start_time, "train/train_time": time.time() - data_end_time}, step=self.progress['step'])

                        for cb in range(self.args.n_codebooks):
                            self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])
                            self.wandb.log({f'train/top10acc_cb{cb+1}': average_top10acc_cbi[cb]}, step=self.progress['step'])
                        self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
                        self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])
                        # write extra losses
                        for key in sum_extra_losses:
                            self.writer.add_scalar(f"train/{key}", sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), self.progress['step'])
                            self.wandb.log({f"train/{key}": sum(sum_extra_losses[key]) / len(sum_extra_losses[key])}, step=self.progress['step'])
                        # logging.info(f"ntoken: {sum_ntoken}")

                # logging
                if self.progress['step'] % self.args.print_every_n_steps == 0:
                    log_out = {}
                    log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
                    log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
                    log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
                    log_out['lr'] = f"{cur_lr:.7f}"
                    log_out['ntokens'] = f"{sum_ntoken}"
                    for key in self.meters:
                        if self.meters[key].val != 0 or self.meters[key].avg != 0:
                            log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
                    logging.info(log_out)
                    if np.isnan(self.meters['train_loss'].avg):
                        logging.warning("training diverged...")
                        raise RuntimeError("training diverged...")

                # save the model only
                if self.progress['step'] % self.args.save_every_n_steps == 0:
                    dist.barrier()
                    if self.rank == 0:
                        save_path = os.path.join(self.args.exp_dir, f"bundle_step{self.progress['step']}.pth")
                        self.save_progress(name=f"step{self.progress['step']}")
                        torch.save(
                            {
                                "model": self.model.module.state_dict(),
                                "args": self.args,
                                "phn2num": self.train_loader.dataset.phn2num,
                                "optimizer": self.optimizer.state_dict(),
                                "scheduler": self.scheduler.state_dict(),
                            }, save_path
                        )
                        logging.info(f"save model, optimizer, scheduler and progress at {save_path} at global step {self.progress['step']}")
                    dist.barrier()

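                # Added note: the bundle keeps args and phn2num (the phoneme-to-index map
                # built by the dataset) next to the weights, so a checkpoint can later be
                # loaded for inference without re-deriving the vocabulary from training data.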
                # validation and save models
                if self.progress['step'] % self.args.val_every_n_steps == 0:
                    dist.barrier()
                    continue_training = self.validate_and_save()
                    # broadcast continue_training to all processes, so that all processes get into the generation stage
                    continue_training = torch.tensor(int(continue_training), dtype=torch.int, device=self.local_rank)
                    dist.broadcast(continue_training, src=0)
                    continue_training = bool(continue_training.item())
                    dist.barrier()  # need this to ensure all processes get to the next line?
                    logging.info(f"rank: {self.rank}, continue_training: {continue_training}")
                    if not continue_training:
                        if self.rank == 0:
                            self.writer.close()
                            self.wandb.finish()
                        flag = False
                        break

                self.progress['step'] += 1
                self.progress['cur_step'] += 1

                data_start_time = time.time()
            self.progress['epoch'] += 1
            self.progress['cur_step'] = 0  # reset cur_step to be 0
        dist.destroy_process_group()

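    # Added note: validate_and_save() below returns False once early stopping has
    # triggered; train() above broadcasts the decision from rank 0 so every rank
    # leaves the training loop at the same step.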
    def validate_and_save(self):
        self.model.eval()

        score = self.validate(self.valid_loader)

        if self.args.early_stop_threshold > 0:
            if self.progress['best_score'] - score < self.args.early_stop_threshold:
                self.early_stop_accu_steps += self.args.val_every_n_steps
                if self.early_stop_accu_steps >= self.args.early_stop_step - 1:
                    logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
                    logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
                    return False
            else:
                self.early_stop_accu_steps = 0

        if self.rank == 0:
            save_path = os.path.join(self.args.exp_dir, "bundle.pth")
            if os.path.isfile(save_path):
                os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
            torch.save(
                {
                    "model": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    "args": self.args,
                    "phn2num": self.train_loader.dataset.phn2num
                }, save_path
            )
            self.save_progress()
            logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")
            if score < self.progress['best_score']:
                self.progress['best_step'] = self.progress['step']
                self.progress['best_score'] = score
                save_path = os.path.join(self.args.exp_dir, "best_bundle.pth")
                if os.path.isfile(save_path):
                    os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
                torch.save(
                    {
                        "model": self.model.module.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.state_dict(),
                        "args": self.args,
                        "phn2num": self.train_loader.dataset.phn2num
                    }, save_path
                )
                logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")

        # sync best score and best step, so that all processes early stop at the same time
        best_score_tensor = torch.tensor(self.progress['best_score'], device=self.local_rank)
        dist.broadcast(best_score_tensor, src=0)
        self.progress['best_score'] = float(best_score_tensor.item())
        best_step_tensor = torch.tensor(self.progress['best_step'], device=self.local_rank)
        dist.broadcast(best_step_tensor, src=0)
        self.progress['best_step'] = int(best_step_tensor.item())
        dist.barrier()
        return True

    def validate(self, valid_loader=None, hide_progress=True):
        if valid_loader is None:
            valid_loader = self.valid_loader
        self.model.eval()

        start_val_time = time.time()
        sum_losses = 0
        sum_top10acc = 0
        sum_ntoken = 0
        sum_dur_loss = 0
        sum_dur_acc = 0
        sum_entropy_loss = 0
        sum_ctc_loss = 0

        sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
        mean_perplexity_cbi = [0 for _ in range(self.args.n_codebooks)]

        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
                out = self.model(batch, calc_loss=True)  # no reduction is applied to loss
                sum_losses += out['loss']
                sum_top10acc += out['top10acc']
                sum_ntoken += out['effective_ntoken']
                if "dur_loss" in out:
                    sum_dur_loss += out['dur_loss']
                    sum_dur_acc += out['dur_acc']
                if "entropy_loss" in out:
                    sum_entropy_loss += out['entropy_loss']
                if "ctc_loss" in out:
                    sum_ctc_loss += out['ctc_loss']
                # logging.info(f"iter {i}::: {sum_losses}, {sum_top10acc}, {sum_ntoken}")
                if 'top10acc_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]

                if 'perplexity_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        mean_perplexity_cbi[cb] += out['perplexity_by_codebook'][cb]
                # if i > 10:
                #     break

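        # Added note: the sums above stay as GPU tensors (the model outputs are never
        # converted to Python floats here), which is what lets dist.all_reduce below
        # aggregate them across ranks in place before they are read out.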
        dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
        if "dur_loss" in out:
            dist.all_reduce(sum_dur_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(sum_dur_acc, op=dist.ReduceOp.SUM)
        if "entropy_loss" in out:
            dist.all_reduce(sum_entropy_loss, op=dist.ReduceOp.SUM)
        if "ctc_loss" in out:
            dist.all_reduce(sum_ctc_loss, op=dist.ReduceOp.SUM)

        if 'top10acc_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)

        if 'perplexity_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(mean_perplexity_cbi[cb], op=dist.ReduceOp.SUM)

        val_loss = sum_losses / sum_ntoken
        val_top10acc = sum_top10acc / sum_ntoken

        if self.rank == 0:
            if "dur_loss" in out:
                val_dur_loss = sum_dur_loss / sum_ntoken
                val_dur_acc = sum_dur_acc / sum_ntoken
                self.meters['val_dur_loss'].update(val_dur_loss)
                logging.info(f"val dur_loss: {val_dur_loss:.5f}")
                self.meters['val_dur_acc'].update(val_dur_acc)
                logging.info(f"val dur_acc: {val_dur_acc:.5f}")
                self.writer.add_scalar("val/dur_loss", val_dur_loss, self.progress['step'])
                self.writer.add_scalar("val/dur_acc", val_dur_acc, self.progress['step'])
                self.wandb.log({"val/dur_loss": val_dur_loss, "val/dur_acc": val_dur_acc}, step=self.progress['step'])
            # logging
            self.meters['val_loss'].update(val_loss)
            logging.info(f"val loss: {val_loss:.5f}")
            self.writer.add_scalar("val/loss", val_loss, self.progress['step'])
            self.wandb.log({"val/loss": val_loss}, step=self.progress['step'])

            self.meters['val_top10acc'].update(val_top10acc)
            logging.info(f"val top10acc: {val_top10acc:.5f}")
            self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
            self.wandb.log({"val/top10acc": val_top10acc}, step=self.progress['step'])
            for cb in range(self.args.n_codebooks):
                average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
                self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
                self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])
                self.wandb.log({f'val/top10acc_cb{cb+1}': average_top10acc_cbi}, step=self.progress['step'])

                temp = mean_perplexity_cbi[cb] / len(valid_loader)
                self.writer.add_scalar(f'val/perplexity_cb{cb+1}', temp, self.progress['step'])
                self.wandb.log({f'val/perplexity_cb{cb+1}': temp}, step=self.progress['step'])

            average_perplexity = sum(mean_perplexity_cbi) / (self.args.n_codebooks * len(valid_loader))
            self.wandb.log({"val/average_perplexity": average_perplexity}, step=self.progress['step'])
            self.writer.add_scalar('val/average_perplexity', average_perplexity, self.progress['step'])

            # log entropy and ctc loss
            if "entropy_loss" in out:
                val_entropy_loss = sum_entropy_loss / ((i+1) * self.world_size)
                self.meters['val_entropy_loss'].update(val_entropy_loss)
                logging.info(f"val entropy_loss: {val_entropy_loss:.5f}")
                self.writer.add_scalar("val/entropy_loss", val_entropy_loss, self.progress['step'])
                self.wandb.log({"val/entropy_loss": val_entropy_loss}, step=self.progress['step'])
            if "ctc_loss" in out:
                val_ctc_loss = sum_ctc_loss / ((i+1) * self.world_size)
                self.meters['val_ctc_loss'].update(val_ctc_loss)
                logging.info(f"val ctc_loss: {val_ctc_loss:.5f}")
                self.writer.add_scalar("val/ctc_loss", val_ctc_loss, self.progress['step'])
                self.wandb.log({"val/ctc_loss": val_ctc_loss}, step=self.progress['step'])

            logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
            logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")

        return val_loss.item()

    def _setup_meters(self):
        meters = {}
        meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
        meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
        meter_names += ['val_perplexity']
        meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_perplexity_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        for name in meter_names:
            meters[name] = AverageMeter()
        return meters

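    # Added note: AverageMeter is imported from elsewhere in this repo; a minimal
    # sketch of the interface the code above relies on (an assumption, not the
    # actual definition):
    #
    #   class AverageMeter:
    #       def __init__(self):
    #           self.val, self.sum, self.count, self.avg = 0, 0, 0, 0
    #       def update(self, val, n=1):
    #           self.val = val
    #           self.sum += val * n
    #           self.count += n
    #           self.avg = self.sum / self.count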
    def _setup_progress(self):
        """
        Set up progress bookkeeping (best step/score, global step, epoch, step within
        the current epoch), restoring it from progress.pkl when resuming.
        """
        progress = {}
        progress['best_step'] = 1
        progress['best_score'] = np.inf  # this records loss value
        progress['step'] = 1
        progress['epoch'] = 1
        progress['cur_step'] = 0  # step in the current epoch, for resuming the sampler
        total_progress = []
        # if self.args.resume or self.args.validate:
        if self.args.resume:
            progress_pkl = "%s/progress.pkl" % self.args.exp_dir
            with open(progress_pkl, "rb") as f:
                total_progress = pickle.load(f)
            progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
            if self.rank == 0:
                logging.info("\nResume training from:")
                logging.info("  epoch = %s" % progress['epoch'])
                logging.info("  cur_step = %s" % progress['cur_step'])
                logging.info("  step = %s" % progress['step'])
                logging.info("  best_step = %s" % progress['best_step'])
                logging.info("  best_score = %s" % progress['best_score'])
        return progress, total_progress

    def save_progress(self, name=None):
        self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
        if name is not None:
            progress_fn = f"{self.args.exp_dir}/progress_{name}.pkl"
        else:
            progress_fn = f"{self.args.exp_dir}/progress.pkl"
        with open(progress_fn, "wb") as f:
            pickle.dump(self.total_progress, f)

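    # Added note: step and cur_step are stored with +1 so that _setup_progress
    # resumes at the step *after* the one that was just completed, rather than
    # re-running it.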
    def _setup_dataloader(self):
        train_dataset, val_dataset = combined_dataset.dataset(self.args, 'train'), combined_dataset.dataset(self.args, 'valid')

        if self.args.dynamic_batching:
            train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
            valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
        else:
            train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
            valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)

        if self.progress['step'] > 1:
            train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])
            assert self.phn2num is not None

        if self.phn2num is not None:
            train_dataset.phn2num = self.phn2num
            val_dataset.phn2num = self.phn2num

        if self.args.dynamic_batching:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                            batch_sampler=train_sampler,
                            num_workers=self.args.num_workers,
                            collate_fn=train_dataset.collate, persistent_workers=True
                            )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                            batch_sampler=valid_sampler,
                            num_workers=self.args.num_workers,
                            collate_fn=val_dataset.collate, persistent_workers=True
                            )
        else:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                            batch_size=self.args.batch_size, sampler=train_sampler, num_workers=self.args.num_workers,
                            collate_fn=train_dataset.collate, persistent_workers=True
                            )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                            batch_size=self.args.batch_size, sampler=valid_sampler,
                            num_workers=self.args.num_workers,
                            collate_fn=val_dataset.collate, persistent_workers=True
                            )
        return len(train_dataset), train_sampler, train_loader, valid_loader

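    # Added note (an assumption about the samplers defined elsewhere in this repo):
    # with dynamic batching, the batch_sampler groups utterances by length so each
    # batch carries roughly max_num_tokens tokens, which is why the DataLoader is
    # built with batch_sampler= instead of a fixed batch_size.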
    def _setup_models(self):
        model = voice_star.VoiceStar(self.args)

        if self.rank == 0:
            logging.info(model)
            logging.info("model parameters")
            print_model_info(model)

        phn2num = None
        optim_states = None
        scheduler_states = None
        if self.progress['step'] > 1:
            bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
            model.load_state_dict(bundle['model'])
            optim_states = bundle['optimizer']
            scheduler_states = bundle['scheduler']
            phn2num = bundle['phn2num']
            if self.rank == 0:
                logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
            del bundle['model']

        if self.args.load_model_from is not None and self.progress['step'] <= 1:
            logging.info(f"load weights from {self.args.load_model_from}")
            sd = torch.load(self.args.load_model_from, map_location="cpu")
            if hasattr(model, "carefully_load_state_dict"):
                model.carefully_load_state_dict(sd['model'])
            else:
                model.load_state_dict(sd['model'])
            phn2num = sd['phn2num']
            del sd

        #### the operations below collect the params for the optimizer, which lives at the wrapper level ####
        if self.args.optimizer_name == "ScaledAdam":
            trainables = [p for p in model.parameters() if p.requires_grad]
        else:
            no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": 0.0,
                },
            ]
            if len(optimizer_grouped_parameters[1]['params']) == 0:
                logging.info("found no embedding weight, bias, or layernorm parameters in the model, which should not be the case; check model parameter names")
                trainables = optimizer_grouped_parameters[0]
            else:
                trainables = optimizer_grouped_parameters
        #### the operations above collect the params for the optimizer, which lives at the wrapper level ####
        model.to(self.device)

        return model, trainables, optim_states, scheduler_states, phn2num

    def _setup_optimizer(self):
        if self.args.optimizer_name == "ScaledAdam":
            parameters_names = []
            _model = self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model
            parameters_names.append([n for n, p in _model.named_parameters() if p.requires_grad])  # use the unwrapped module so names match the trainables collected before DDP wrapping
            optimizer = ScaledAdam(
                self.trainables,
                lr=self.args.lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=self.args.clipping_update_period,
            )
            scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)  # NOTE: if using ScaledAdam, we will use the Eden scheduler!

        else:
            optimizer = AdamW(self.trainables, lr=self.args.lr)
            warmup_steps = self.total_step * self.args.warmup_fraction
            def lr_lambda(current_step: int):
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                return max(
                    0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
                )

            scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

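        # Added illustration (not from the original code): with total_step=100000 and
        # warmup_fraction=0.1, warmup_steps=10000; lr_lambda(5000)=0.5 on the way up,
        # lr_lambda(10000)=1.0 at the peak, and lr_lambda(55000)=(100000-55000)/90000=0.5
        # on the linear decay back to 0.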
        # if resume
        if self.progress['step'] > 1:
            optimizer.load_state_dict(self.optim_states)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            del self.optim_states

            scheduler.load_state_dict(self.scheduler_states)

        optimizer.zero_grad()
        return optimizer, scheduler

    def seed_everything(self, seed=1):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
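        # Added note: this covers the usual RNGs, but exact reproducibility across runs
        # may additionally require seeding DataLoader workers and, on newer PyTorch,
        # torch.use_deterministic_algorithms(True).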