jzq11111 committed on
Commit a3e05e8 · verified · 1 Parent(s): 9084c21

Upload folder using huggingface_hub

Files changed (48)
  1. .gitattributes +4 -0
  2. .gitignore +7 -0
  3. .gradio/certificate.pem +31 -0
  4. LICENSE +21 -0
  5. README.md +3 -9
  6. app.py +211 -0
  7. download_pretrain.py +2 -0
  8. en_prompt0.wav +3 -0
  9. en_prompt1.wav +3 -0
  10. inference.py +342 -0
  11. modules/audio_detokenizer/audio_detokenizer.py +249 -0
  12. modules/audio_detokenizer/bigvgan_wrapper.py +94 -0
  13. modules/audio_detokenizer/flow_matching/dit_block.py +236 -0
  14. modules/audio_detokenizer/flow_matching/model.py +295 -0
  15. modules/audio_detokenizer/flow_matching/ode_wrapper.py +164 -0
  16. modules/audio_detokenizer/flow_matching/scheduler.py +82 -0
  17. modules/audio_detokenizer/semantic_fm_prefix_streaming.py +273 -0
  18. modules/audio_detokenizer/vocoder/activations.py +123 -0
  19. modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py +0 -0
  20. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py +0 -0
  21. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py +77 -0
  22. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  23. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  24. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h +29 -0
  25. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py +86 -0
  26. modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h +92 -0
  27. modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py +6 -0
  28. modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py +30 -0
  29. modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py +101 -0
  30. modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py +58 -0
  31. modules/audio_detokenizer/vocoder/bigvgan.py +492 -0
  32. modules/audio_detokenizer/vocoder/utils.py +105 -0
  33. modules/audio_tokenizer/audio_tokenizer.py +76 -0
  34. modules/audio_tokenizer/quantize/__init__.py +3 -0
  35. modules/audio_tokenizer/quantize/factorized_vector_quantize.py +145 -0
  36. modules/audio_tokenizer/quantize/residual_vq.py +168 -0
  37. modules/audio_tokenizer/quantize/vector_quantize.py +396 -0
  38. modules/audio_tokenizer/rep_codec.py +197 -0
  39. modules/audio_tokenizer/transformer.py +234 -0
  40. modules/audio_tokenizer/vocos.py +845 -0
  41. modules/tokenizer/tokenizer.py +243 -0
  42. readme.md +39 -0
  43. requirements.txt +18 -0
  44. test/test_audio_detokenizer.py +34 -0
  45. test/test_audio_tokenizer.py +15 -0
  46. test/test_tokenizer.py +37 -0
  47. zh_prompt0.wav +3 -0
  48. zh_prompt1.wav +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ en_prompt0.wav filter=lfs diff=lfs merge=lfs -text
+ en_prompt1.wav filter=lfs diff=lfs merge=lfs -text
+ zh_prompt0.wav filter=lfs diff=lfs merge=lfs -text
+ zh_prompt1.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ *.safetensors
+ *.pt
+ *.vscode
+ **/__pycache__/
+ modules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/
+ tmp*
+ resources/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Zeqian Ju
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Mooncast
- emoji: 🚀
- colorFrom: red
- colorTo: purple
- sdk: gradio
- sdk_version: 5.23.0
+ title: mooncast
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 5.22.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,211 @@
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ from huggingface_hub import snapshot_download
5
+ from inference import Model
6
+ import base64
7
+
8
+ snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
9
+ model = Model()
10
+ model.generate_config.max_new_tokens = 50 * 50 # no more than ~50s per turn (about 50 speech tokens per second)
11
+
12
+
13
+ def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str):
14
+ try:
15
+ print(json_dialogue_input_str, type(json_dialogue_input_str))
16
+ print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1)
17
+ # json_data = json.loads(json_dialogue_input_str)
18
+ json_data = eval(json_dialogue_input_str.strip())
19
+ print(json_data, type(json_data))
20
+
21
+ def validate_json(data):
22
+ try:
23
+ if not isinstance(data, list):
24
+ return "the dialogue JSON must be a list of turns"
25
+ cur_spk_should_be = 0
26
+ for item in data:
27
+ if item['role'] != str(cur_spk_should_be):
28
+ return f"role should be {cur_spk_should_be} in item {item}"
29
+ cur_spk_should_be = 1 - cur_spk_should_be
30
+ return None
31
+ except Exception as e:
32
+ return str(e)
33
+
34
+
35
+ validation_error = validate_json(json_data)
36
+ if validation_error:
37
+ raise gr.Error(validation_error)
38
+
39
+ role_mapping = {
40
+ "0": {
41
+ "ref_audio": prompt_audio_role0_file,
42
+ "ref_text": prompt_text_role0,
43
+ },
44
+ "1": {
45
+ "ref_audio": prompt_audio_role1_file,
46
+ "ref_text": prompt_text_role1,
47
+ }
48
+ }
49
+
50
+ # full input JSON for the model (adjust this to match your model)
51
+ model_input_json = {
52
+ "role_mapping": role_mapping,
53
+ "dialogue": json_data, # the dialogue list comes from the user-supplied JSON
54
+ }
55
+ print("model inference input JSON:", model_input_json)
56
+
57
+
58
+ # 4. [important] call the `inference` method of your Model class
59
+ # audio_bytes = model.inference(model_input_json)
60
+
61
+ # 5. return the audio bytes to Gradio (Gradio handles the bytes and plays them)
62
+ # return base64.b64decode(audio_bytes)
63
+ for cur_chunk in model.inference(model_input_json, streaming=True):
64
+ yield base64.b64decode(cur_chunk)
65
+
66
+ except Exception as e:
67
+ # return str(e) # return the error message to Gradio
68
+ raise gr.Error(str(e))
69
+
70
+ title_en = "# PODCAST generator (supports English and Chinese)"
71
+ title_zh = "# 播客生成 (支持英文和中文)"
72
+
73
+ input_labels_en = ["Prompt Audio for Role 0", "Prompt Text for Role 0", "Prompt Audio for Role 1", "Prompt Text for Role 1", "Dialogue JSON Input"]
74
+ input_labels_zh = ["角色 0 的 Prompt 音频", "角色 0 的 Prompt 文本", "角色 1 的 Prompt 音频", "角色 1 的 Prompt 文本", "对话 JSON 输入"]
75
+
76
+ output_label_en = "Generated Audio Output (streaming)"
77
+ output_label_zh = "生成的音频输出(流式)"
78
+
79
+ example_prompt_text_role0_en = "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem."
80
+ example_prompt_text_role0_zh = "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。"
81
+ example_prompt_text_role1_en = "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."
82
+ example_prompt_text_role1_zh = "他最后就能让同样食材炒出来的菜味道大大提升。"
83
+
84
+ text_placeholder_zh = "对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。"
85
+ text_placeholder_en = "Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio."
86
+
87
+
88
+ example_json_en = '''[
89
+ {
90
+ "role": "0",
91
+ "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
92
+ },
93
+ {
94
+ "role": "1",
95
+ "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.",
96
+ },
97
+ {
98
+ "role": "0",
99
+ "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.",
100
+ },
101
+ {
102
+ "role": "1",
103
+ "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
104
+ }
105
+ ]'''
106
+ example_json_zh = '''[
107
+ {
108
+ "role": "0",
109
+ "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
110
+ },
111
+ {
112
+ "role": "1",
113
+ "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
114
+ },
115
+ {
116
+ "role": "0",
117
+ "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
118
+ }
119
+ ]
120
+ '''
121
+
122
+ # examples_en = [
123
+ # ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en]
124
+ # ]
125
+ # examples_zh = [
126
+ # ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
127
+ # ]
128
+
129
+ examples = [
130
+ ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en],
131
+ ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
132
+ ]
133
+
134
+ # -------------------- helper to update UI element labels --------------------
135
+ def update_ui_language(language):
136
+ if language == "English":
137
+ return gr.update(value=title_en), \
138
+ gr.update(label="UI Language"), \
139
+ gr.update(label=input_labels_en[0]), \
140
+ gr.update(label=input_labels_en[1]), \
141
+ gr.update(label=input_labels_en[2]), \
142
+ gr.update(label=input_labels_en[3]), \
143
+ gr.update(label=input_labels_en[4], placeholder=text_placeholder_en), \
144
+ gr.update(label=output_label_en), \
145
+ gr.update(value="Submit"), \
146
+ gr.update(label="Examples (Demonstration Use Only. Do Not Redistribute.)", headers=input_labels_en)
147
+
148
+ elif language == "中文":
149
+ return gr.update(value=title_zh), \
150
+ gr.update(label="UI 语言"), \
151
+ gr.update(label=input_labels_zh[0]), \
152
+ gr.update(label=input_labels_zh[1]), \
153
+ gr.update(label=input_labels_zh[2]), \
154
+ gr.update(label=input_labels_zh[3]), \
155
+ gr.update(label=input_labels_zh[4], placeholder=text_placeholder_zh), \
156
+ gr.update(label=output_label_zh), \
157
+ gr.update(value="提交"), \
158
+ gr.update(label="示例 (仅用于展示,切勿私自传播。)", headers=input_labels_zh)
159
+
160
+ else:
161
+ raise ValueError("Invalid language selected")
162
+
163
+
164
+ audio_output = gr.Audio(label=output_label_en, streaming=True)
165
+ css = """
166
+ .centered-title { /* CSS rule for centering title */
167
+ text-align: center !important;
168
+ }
169
+ """
170
+ # -------------------- Gradio UI definition (modified) --------------------
171
+ with gr.Blocks(css=css) as iface:
172
+
173
+ title_output = gr.Markdown(value=title_zh, elem_classes="centered-title")
174
+ language_choice = gr.Radio(["中文", "English"], value="中文", label="UI语言")
175
+
176
+ with gr.Row(): # Main row to create two columns
177
+ with gr.Column(scale=2):
178
+ json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh) # Dialogue JSON Input
179
+
180
+ with gr.Column(scale=1): # Right column (narrower - scale=1) for prompt inputs
181
+ audio_input_role0 = gr.Audio(type="filepath", label=input_labels_zh[0]) # Prompt Audio for Role 0
182
+ text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2) # Prompt Text for Role 0
183
+
184
+ with gr.Column(scale=1): #
185
+ audio_input_role1 = gr.Audio(type="filepath", label=input_labels_zh[2]) # Prompt Audio for Role 1
186
+ text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2) # Prompt Text for Role 1
187
+
188
+ examples_component = gr.Examples(
189
+ examples=examples,
190
+ inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
191
+ cache_examples=False,
192
+ label="示例(仅用于展示,切勿私自传播。)",
193
+ )
194
+
195
+ submit_button = gr.Button("提交")
196
+
197
+ submit_button.click(
198
+ fn=process_json_and_generate_audio,
199
+ inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
200
+ outputs=audio_output
201
+ )
202
+ audio_output.render()
203
+
204
+ language_choice.change(
205
+ fn=update_ui_language,
206
+ inputs=language_choice,
207
+ outputs=[title_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset]
208
+ )
209
+
210
+
211
+ iface.launch(share=True)
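
A note on `process_json_and_generate_audio` above: it parses the textbox with `eval()`, apparently because the bundled examples contain trailing commas that `json.loads` rejects, and the dialogue must be a list of turns whose `role` values alternate between "0" and "1". A minimal sketch of a safer parser with the same acceptance behaviour (standard library only; not part of this commit):

```python
import ast
import json

def parse_dialogue(raw: str):
    """Parse the dialogue textbox; accepts strict JSON or Python-literal syntax."""
    raw = raw.strip()
    try:
        return json.loads(raw)            # strict JSON first
    except json.JSONDecodeError:
        # ast.literal_eval only evaluates literals (lists, dicts, strings, numbers),
        # so it tolerates the trailing commas in the examples without running code.
        return ast.literal_eval(raw)

def validate_dialogue(data):
    """Return None if valid, else an error message."""
    if not isinstance(data, list):
        return "the dialogue JSON must be a list of turns"
    expected = 0
    for item in data:
        if item.get("role") != str(expected):
            return f"role should be {expected} in item {item}"
        expected = 1 - expected
    return None
```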
download_pretrain.py ADDED
@@ -0,0 +1,2 @@
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
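
For reference, the snapshot downloaded above is expected to provide the checkpoint paths hard-coded in `inference.py` and `modules/audio_detokenizer/audio_detokenizer.py`; the audio tokenizer may need additional files not listed here. A small sanity-check sketch based on those paths (not part of this commit):

```python
from pathlib import Path

# Paths referenced by Model.__init__ and get_audio_detokenizer().
EXPECTED = [
    "resources/text2semantic",                 # AutoModelForCausalLM checkpoint directory
    "resources/audio_detokenizer/config.yaml",
    "resources/audio_detokenizer/model.pt",
    "resources/vocoder/config.json",
    "resources/vocoder/model.pt",
]

missing = [p for p in EXPECTED if not Path(p).exists()]
if missing:
    raise FileNotFoundError(f"snapshot_download finished but these paths are missing: {missing}")
print("all expected pretrained resources are in place")
```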
en_prompt0.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d789de64f286dd5d0198e67310bb334d9c0bfe95b3369747ccab58ae614c3292
+ size 684388
en_prompt1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:407697746e6a6da1d540fe06379c54c36dd3b1a06fb3a789eb457b981d2cc7f4
+ size 615116
inference.py ADDED
@@ -0,0 +1,342 @@
1
+
2
+ import sys
3
+ sys.path.append(".")
4
+ from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
5
+ from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
6
+ from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref
7
+ import torch
8
+ import os
9
+ from glob import glob
10
+ import base64
11
+ import io
12
+ import torchaudio
13
+ from transformers import AutoModelForCausalLM, GenerationConfig
14
+ import librosa
15
+ from tqdm import tqdm
16
+
17
+ class Model(object):
18
+ def __init__(self):
19
+
20
+
21
+ self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
22
+ self.speech_token_offset = 163840
23
+ print(self.extra_tokens)
24
+ self.assistant_ids = self.tokenizer.encode("assistant") # [110866]
25
+ self.user_ids = self.tokenizer.encode("user") # [1495]
26
+ self.audio_ids = self.tokenizer.encode("audio") # [26229]
27
+ self.spk_0_ids = self.tokenizer.encode("0") # [501]
28
+ self.spk_1_ids = self.tokenizer.encode("1") # [503]
29
+
30
+ self.msg_end = self.extra_tokens.msg_end # 260
31
+ self.user_msg_start = self.extra_tokens.user_msg_start # 261
32
+ self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262
33
+ self.name_end = self.extra_tokens.name_end # 272
34
+ self.media_begin = self.extra_tokens.media_begin # 273
35
+ self.media_content = self.extra_tokens.media_content # 274
36
+ self.media_end = self.extra_tokens.media_end # 275
37
+
38
+ self.audio_tokenizer = get_audio_tokenizer()
39
+ self.audio_detokenizer = get_audio_detokenizer()
40
+ model_path = "resources/text2semantic"
41
+ self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.bfloat16, trust_remote_code=True, force_download=True).to(torch.cuda.current_device())
42
+ self.generate_config = GenerationConfig(
43
+ max_new_tokens=200 * 50, # no more than 200s per turn
44
+ do_sample=True,
45
+ top_k=30,
46
+ top_p=0.8,
47
+ temperature=0.8,
48
+ eos_token_id=self.media_end,
49
+ )
50
+
51
+ def _clean_text(self, text):
52
+ # you can add front-end processing here
53
+ text = text.replace("“", "")
54
+ text = text.replace("”", "")
55
+ text = text.replace("...", " ")
56
+ text = text.replace("…", " ")
57
+ text = text.replace("*", "")
58
+ text = text.replace(":", ",")
59
+ text = text.replace("‘", "'")
60
+ text = text.replace("’", "'")
61
+ text = text.strip()
62
+ return text
63
+
64
+ @torch.inference_mode()
65
+ def _process_text(self, js):
66
+
67
+ if "role_mapping" in js:
68
+ for role in js["role_mapping"].keys():
69
+ js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
70
+
71
+ for turn in js["dialogue"]:
72
+ turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
73
+ return js
74
+
75
+ def inference(self, js, streaming=False):
76
+ js = self._process_text(js)
77
+ if "role_mapping" not in js:
78
+ return self.infer_without_prompt(js, streaming)
79
+ else:
80
+ return self.infer_with_prompt(js, streaming)
81
+
82
+ @torch.inference_mode()
83
+ def infer_with_prompt(self, js, streaming=False):
84
+ user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
85
+ user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
86
+ assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
87
+ assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
88
+
89
+ media_start = [self.media_begin] + self.audio_ids + [self.media_content]
90
+ media_end = [self.media_end] + [self.msg_end]
91
+
92
+ assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
93
+ assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
94
+ media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
95
+ media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
96
+
97
+
98
+ prompt = []
99
+ cur_role_dict = dict()
100
+ for role, role_item in js["role_mapping"].items():
101
+ waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
102
+ waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
103
+
104
+ waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
105
+ waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
106
+
107
+ semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
108
+ semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
109
+ prompt_ids = semantic_tokens + self.speech_token_offset
110
+
111
+ cur_role_dict[role] = {
112
+ "ref_bpe_ids": role_item["ref_bpe_ids"],
113
+ "wav_24k": waveform_24k,
114
+ "semantic_tokens": semantic_tokens,
115
+ "prompt_ids": prompt_ids
116
+ }
117
+
118
+ prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
119
+ prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
120
+
121
+ for seg_id, turn in enumerate(js["dialogue"]):
122
+ role_id = turn["role"]
123
+ cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
124
+ cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
125
+ prompt = prompt + cur_start_ids
126
+
127
+ prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
128
+
129
+ prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
130
+ prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
131
+
132
+
133
+ generation_config = self.generate_config
134
+ # you can modify sampling strategy here
135
+
136
+ wav_list = []
137
+ for seg_id, turn in tqdm(enumerate(js["dialogue"])):
138
+ role_id = turn["role"]
139
+ cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
140
+ prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
141
+ len_prompt = prompt.shape[1]
142
+ generation_config.min_length = len_prompt + 2
143
+ # print(generation_config)
144
+ # todo: add streaming support for generate function
145
+ outputs = self.model.generate(prompt,
146
+ generation_config=generation_config)
147
+ if outputs[0, -1] == self.media_end:
148
+ outputs = outputs[:, :-1]
149
+ output_token = outputs[:, len_prompt:]
150
+ prompt = torch.cat([outputs, media_end], dim=-1)
151
+
152
+ torch_token = output_token - self.speech_token_offset
153
+ if streaming:
154
+ # gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
155
+ # yield detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
156
+ for cur_chunk in detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"], streaming=True):
157
+ cur_chunk = cur_chunk.cpu()
158
+ cur_chunk = cur_chunk / cur_chunk.abs().max()
159
+ cur_buffer = io.BytesIO()
160
+ torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
161
+ audio_bytes = cur_buffer.getvalue()
162
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
163
+ yield audio_b64
164
+ else:
165
+ gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
166
+ gen_speech_fm = gen_speech_fm.cpu()
167
+ gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
168
+ wav_list.append(gen_speech_fm)
169
+ del torch_token
170
+ if not streaming:
171
+ concat_wav = torch.cat(wav_list, dim=-1).cpu()
172
+ # print(concat_wav.shape)
173
+ buffer = io.BytesIO()
174
+ torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
175
+ audio_bytes = buffer.getvalue()
176
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
177
+ return audio_b64
178
+
179
+ @torch.inference_mode()
180
+ def infer_without_prompt(self, js, streaming=False):
181
+ user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
182
+ user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
183
+ assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
184
+ assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
185
+
186
+ media_start = [self.media_begin] + self.audio_ids + [self.media_content]
187
+ media_end = [self.media_end] + [self.msg_end]
188
+
189
+ assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
190
+ assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
191
+ media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
192
+ media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
193
+
194
+
195
+ prompt = []
196
+ for seg_id, turn in enumerate(js["dialogue"]):
197
+ role_id = turn["role"]
198
+ cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
199
+ cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
200
+ prompt = prompt + cur_start_ids
201
+
202
+ prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
203
+ generation_config = self.generate_config
204
+ # you can modify sampling strategy here
205
+
206
+ wav_list = []
207
+ for seg_id, turn in tqdm(enumerate(js["dialogue"])):
208
+ role_id = turn["role"]
209
+ cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
210
+ prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
211
+ len_prompt = prompt.shape[1]
212
+ generation_config.min_length = len_prompt + 2
213
+ # print(generation_config)
214
+ # todo: add streaming support for generate function
215
+ outputs = self.model.generate(prompt,
216
+ generation_config=generation_config)
217
+ if outputs[0, -1] == self.media_end:
218
+ outputs = outputs[:, :-1]
219
+ output_token = outputs[:, len_prompt:]
220
+ prompt = torch.cat([outputs, media_end], dim=-1)
221
+
222
+ torch_token = output_token - self.speech_token_offset
223
+ if streaming:
224
+ for cur_chunk in detokenize_noref(self.audio_detokenizer, torch_token, streaming=True):
225
+ cur_chunk = cur_chunk.cpu()
226
+ cur_chunk = cur_chunk / cur_chunk.abs().max()
227
+ cur_buffer = io.BytesIO()
228
+ torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
229
+ audio_bytes = cur_buffer.getvalue()
230
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
231
+ yield audio_b64
232
+ else:
233
+ gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
234
+ gen_speech_fm = gen_speech_fm.cpu()
235
+ gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
236
+ wav_list.append(gen_speech_fm)
237
+ del torch_token
238
+
239
+ if not streaming:
240
+ concat_wav = torch.cat(wav_list, dim=-1).cpu()
241
+ # print(concat_wav.shape)
242
+ buffer = io.BytesIO()
243
+ torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
244
+ audio_bytes = buffer.getvalue()
245
+ audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
246
+ return audio_b64
247
+
248
+
249
+ if __name__ == "__main__":
250
+ model = Model()
251
+
252
+ # speaker should be interleaved
253
+ zh_test_json = {
254
+ "role_mapping": {
255
+ "0": {
256
+ "ref_audio": "./zh_prompt0.wav",
257
+ "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。", #asr output
258
+ },
259
+ "1": {
260
+ "ref_audio": "./zh_prompt1.wav",
261
+ "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。" #asr output
262
+ }
263
+ },
264
+ "dialogue": [
265
+ {
266
+ "role": "0",
267
+ "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
268
+ },
269
+ {
270
+ "role": "1",
271
+ "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
272
+ },
273
+ {
274
+ "role": "0",
275
+ "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
276
+ }
277
+ ]
278
+ }
279
+ audio_bytes = model.inference(zh_test_json)
280
+ file_to_save = open(f"tmp_generated_zh.mp3", "wb")
281
+ file_to_save.write(base64.b64decode(audio_bytes))
282
+ print("zh done")
283
+
284
+ # speaker should be interleaved
285
+ en_test_json = {
286
+ "role_mapping": {
287
+ "0": {
288
+ "ref_audio": "./en_prompt0.wav",
289
+ "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.", #asr output
290
+ },
291
+ "1": {
292
+ "ref_audio": "./en_prompt1.wav",
293
+ "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." #asr output
294
+ }
295
+ },
296
+ "dialogue": [
297
+ {
298
+ "role": "0",
299
+ "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
300
+ },
301
+ {
302
+ "role": "1",
303
+ "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
304
+ },
305
+ {
306
+ "role": "0",
307
+ "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
308
+ },
309
+ {
310
+ "role": "1",
311
+ "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
312
+ }
313
+ ]
314
+ }
315
+ audio_bytes = model.inference(en_test_json)
316
+ file_to_save = open(f"tmp_generated_en.mp3", "wb")
317
+ file_to_save.write(base64.b64decode(audio_bytes))
318
+ print("en done")
319
+
320
+
321
+ # also support inference without prompt
322
+ # speaker should be interleaved
323
+ without_prompt_test_json = {
324
+ "dialogue": [
325
+ {
326
+ "role": "0",
327
+ "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
328
+ },
329
+ {
330
+ "role": "1",
331
+ "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
332
+ },
333
+ {
334
+ "role": "0",
335
+ "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
336
+ }
337
+ ]
338
+ }
339
+ audio_bytes = model.inference(without_prompt_test_json)
340
+ file_to_save = open(f"tmp_generated_woprompt.mp3", "wb")
341
+ file_to_save.write(base64.b64decode(audio_bytes))
342
+ print("without prompt done")
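
A usage note on `Model.inference`: because `infer_with_prompt` and `infer_without_prompt` both contain `yield`, calling `inference(...)` returns a generator object even when `streaming=False`, so the non-streaming `return audio_b64` only surfaces through `StopIteration.value` and the `__main__` demo's `base64.b64decode(audio_bytes)` would receive a generator rather than a string. A minimal sketch (not part of this commit) that consumes the streaming path and writes a single file:

```python
import base64

def synthesize_to_file(model, request_json, out_path="tmp_generated.mp3"):
    """Consume Model.inference(..., streaming=True) and append the decoded MP3 chunks."""
    with open(out_path, "wb") as f:
        for chunk_b64 in model.inference(request_json, streaming=True):
            f.write(base64.b64decode(chunk_b64))
    return out_path

# usage with the zh_test_json dict defined in the __main__ block above:
# synthesize_to_file(Model(), zh_test_json, "tmp_generated_zh.mp3")
```

Concatenated MP3 segments written by `torchaudio.save` generally play back as one continuous stream, which is essentially what the Gradio app does with the same chunks.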
modules/audio_detokenizer/audio_detokenizer.py ADDED
@@ -0,0 +1,249 @@
1
+
2
+ import torch
3
+
4
+ from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper
5
+ from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper
6
+
7
+
8
+ class PrefixStreamingFlowMatchingDetokenizer:
9
+ def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None:
10
+ self.dtype = torch.bfloat16
11
+
12
+ print("Currently using bfloat16 for PrefixStreamingFlowMatchingDetokenizer")
13
+
14
+ self.vocoder = vocoder
15
+ self.vocoder.to_dtype(self.dtype)
16
+
17
+ self.semantic_fm = fm
18
+
19
+ # initialize mel_spec
20
+ self.max_pos_size = 4096
21
+ self.is_timbre_semantic_token = False
22
+ self.pre_mel = None
23
+ self.frame_size = 480 # how many samples in a frame
24
+ self.pre_wav = None
25
+ self.state_dict_backup = None
26
+ self.hamming_window_cache = {}
27
+ self.previous_chunk_left = None
28
+ self.look_ahead_tokens = look_ahead_tokens
29
+
30
+ self.clear_states()
31
+
32
+
33
+ @classmethod
34
+ def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device,
35
+ look_ahead_tokens=0,
36
+ max_prompt_chunk=2, max_kv_cache_tokens=900,
37
+ use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
38
+ bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
39
+ semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
40
+ use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule)
41
+ return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)
42
+
43
+ @torch.inference_mode()
44
+ def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):
45
+ """
46
+ Arguments:
47
+ timbre_speech: torch.Tensor, shape [B, N_speech_24k]
48
+ timbre_semantic_token: torch.Tensor, shape [B, N]
49
+ chunk_size: int, chunk size for prefilling
50
+ timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech
51
+ """
52
+ if timbre_mel is None:
53
+ assert timbre_speech is not None, "timbre_speech must be provided when timbre_mel is None"
54
+ assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0
55
+ assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
56
+
57
+ mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))
58
+ else:
59
+ assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0
60
+ assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
61
+ mel_spec = timbre_mel.squeeze(0)
62
+
63
+ if mel_spec.shape[0] < timbre_semantic_token.shape[1]:
64
+ # pad mel_spec
65
+ mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0]))
66
+ elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:
67
+ # truncate mel_spec
68
+ mel_spec = mel_spec[:timbre_semantic_token.shape[1], :]
69
+
70
+ # clear all states
71
+ self.semantic_fm.clear_all_states()
72
+ self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)
73
+ self.state_dict_backup = self.semantic_fm.state_dict()
74
+
75
+ @torch.inference_mode()
76
+ def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1):
77
+ assert len(semantic_token.shape) == 2 and ode_step > 0
78
+ assert semantic_token.shape[0] == 1
79
+
80
+ semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)
81
+
82
+ semantic_token = semantic_token.squeeze(0)
83
+
84
+ if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
85
+ semantic_token_previous = self.previous_chunk_left["semantic_token"]
86
+ semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1)
87
+
88
+ x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)
89
+
90
+ if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
91
+ self.previous_chunk_left = {"semantic_token": None}
92
+
93
+ speech_mel = self.semantic_fm.infer_chunk(
94
+ xt_chunk=x_t_chunk,
95
+ semantic_tokens_chunk=semantic_token,
96
+ start_position_id=self.semantic_fm.start_position_id,
97
+ ode_steps=ode_step,
98
+ verbose=verbose,
99
+ look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,
100
+ cache=self.previous_chunk_left,
101
+ ode_solver=ode_solver
102
+ )
103
+
104
+ chunk_size = speech_mel.shape[0]
105
+ length = speech_mel.shape[0]
106
+ self.semantic_fm.start_position_id += length
107
+ self.semantic_fm.update_incremental_state()
108
+ self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens
109
+
110
+ # smoothing
111
+
112
+ # Maintain a history of the generated waveform across chunks.
113
+ # For the first chunk, return only the first half and keep the remaining half as history.
114
+ # For later chunks, concatenate the new output with the history, cross-fade the overlap with a Hamming window, emit one chunk, and keep the remainder as the new history.
115
+
116
+ if self.pre_mel is None: # first chunk, related to TTFB
117
+ concat_mel = speech_mel
118
+ concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
119
+ if is_final:
120
+ self.clear_states()
121
+ self.state_dict_backup = None
122
+ ret_wav = concat_reconstructed_wav.float()
123
+ else:
124
+ reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk
125
+
126
+ self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step
127
+ self.pre_mel = speech_mel[-chunk_size//2:, :]
128
+
129
+ ret_wav = reconstructed_wav.float()
130
+ else:
131
+ concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)
132
+ concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
133
+
134
+ if is_final:
135
+ self.clear_states()
136
+ self.state_dict_backup = None
137
+ ret_wav = concat_reconstructed_wav.float()
138
+ else:
139
+ # fetch history
140
+ prev_speech_len = self.pre_wav.shape[1]
141
+
142
+ if concat_reconstructed_wav.shape[1] > prev_speech_len * 2:
143
+ gen_speech_len = prev_speech_len * 2
144
+ else:
145
+ gen_speech_len = concat_reconstructed_wav.shape[1] // 2
146
+
147
+
148
+ reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # return the first half chunk
149
+
150
+ if gen_speech_len not in self.hamming_window_cache:
151
+ self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)
152
+
153
+ hamming_window = self.hamming_window_cache[gen_speech_len]
154
+
155
+
156
+ # apply smoothing of the first half chunk
157
+ reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \
158
+ reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)]
159
+
160
+ res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len
161
+ res_mel_len = res_speech_len // self.frame_size
162
+
163
+ self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]
164
+ self.pre_mel = speech_mel[-res_mel_len:, :]
165
+ ret_wav = reconstructed_wav.float()
166
+
167
+ if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size:
168
+ # position ids exhausted: reset all states and restore the prefilled prompt from the backup
169
+ self.semantic_fm.clear_all_states()
170
+ self.semantic_fm.load_state_dict(self.state_dict_backup)
171
+
172
+ return ret_wav
173
+
174
+ def clear_states(self):
175
+ self.semantic_fm.clear_all_states()
176
+ self.previous_chunk_left = None
177
+ self.pre_mel = None
178
+ self.pre_wav = None
179
+
180
+ def get_audio_detokenizer():
181
+ fm_model_config = "resources/audio_detokenizer/config.yaml"
182
+ fm_ckpt_path = "resources/audio_detokenizer/model.pt"
183
+
184
+ bigvgan_config_file = "resources/vocoder/config.json"
185
+ bigvgan_ckpt_path = "resources/vocoder/model.pt"
186
+
187
+ device=torch.cuda.current_device()
188
+ detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
189
+ vocoder_config=bigvgan_config_file,
190
+ vocoder_ckpt=bigvgan_ckpt_path,
191
+ max_prompt_chunk=10, # 10 * 3 = 30s
192
+ fm_config=fm_model_config,
193
+ fm_ckpt=fm_ckpt_path,
194
+ device=device,
195
+ use_cfg=False,
196
+ look_ahead_tokens=12)
197
+
198
+ return detokenizer
199
+
200
+
201
+ def detokenize(detokenizer, tokens, ref_wav, ref_tokens, streaming=False):
202
+ with torch.no_grad():
203
+ detokenizer.clear_states()
204
+ detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
205
+ cache_speech_collection = []
206
+ chunk_size = 150
207
+ first_chunk_size = 100
208
+ first_chunk_tokens = tokens[:, :first_chunk_size]
209
+ gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
210
+ if streaming:
211
+ yield gen_speech
212
+ else:
213
+ cache_speech_collection.append(gen_speech)
214
+ res_tokens = tokens[:, first_chunk_size:]
215
+ for i in range(0, res_tokens.size(1), chunk_size):
216
+ chunk_tokens = res_tokens[:, i:i+chunk_size]
217
+ gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
218
+ if streaming:
219
+ yield gen_speech
220
+ else:
221
+ cache_speech_collection.append(gen_speech)
222
+ if not streaming:
223
+ gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
224
+ return gen_speech_all
225
+
226
+ def detokenize_noref(detokenizer, tokens, streaming=False):
227
+ with torch.no_grad():
228
+ detokenizer.clear_states()
229
+ cache_speech_collection = []
230
+ chunk_size = 150
231
+ first_chunk_size = 100
232
+ first_chunk_tokens = tokens[:, :first_chunk_size]
233
+ gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
234
+ if streaming:
235
+ yield gen_speech
236
+ else:
237
+ cache_speech_collection.append(gen_speech)
238
+ res_tokens = tokens[:, first_chunk_size:]
239
+ for i in range(0, res_tokens.size(1), chunk_size):
240
+ chunk_tokens = res_tokens[:, i:i+chunk_size]
241
+ gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
242
+ if streaming:
243
+ yield gen_speech
244
+ else:
245
+ cache_speech_collection.append(gen_speech)
246
+ if not streaming:
247
+ gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
248
+ return gen_speech_all
249
+
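
The same generator caveat applies to `detokenize` and `detokenize_noref`: both contain `yield`, so they always return generators, and the non-streaming `return gen_speech_all` is only reachable through `StopIteration.value`. A minimal sketch (not part of this commit) that collects the full waveform from the streaming path:

```python
import torch

from modules.audio_detokenizer.audio_detokenizer import detokenize

def detokenize_full(detokenizer, tokens, ref_wav, ref_tokens):
    """Concatenate the streamed chunks into one [1, T] waveform at 24 kHz."""
    chunks = [c.cpu() for c in detokenize(detokenizer, tokens, ref_wav, ref_tokens, streaming=True)]
    return torch.cat(chunks, dim=-1)
```

The first chunk is deliberately shorter (100 tokens versus 150 for later chunks) to reduce time-to-first-byte, matching the "related to TTFB" comment in `detokenize_streaming`.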
modules/audio_detokenizer/bigvgan_wrapper.py ADDED
@@ -0,0 +1,94 @@
1
+ import os
2
+ import json
3
+ import logging
4
+
5
+ import librosa
6
+ import torch
7
+
8
+ from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
9
+ from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class BigVGANWrapper:
15
+ def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
16
+ self.vocoder = vocoder.to(device)
17
+ if dtype is not None:
18
+ self.vocoder = self.vocoder.to(dtype)
19
+ self.vocoder = self.vocoder.eval()
20
+ self.device = device
21
+ self.h = h
22
+
23
+ def to_dtype(self, dtype):
24
+ self.vocoder = self.vocoder.to(dtype)
25
+
26
+ def extract_mel_from_wav(self, wav_path=None, wav_data=None):
27
+ """
28
+ params:
29
+ wav_path: str, path of the wav, should be 24k
30
+ wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
31
+ return:
32
+ mel: [T, num_mels], torch.tensor
33
+ """
34
+ if wav_data is None:
35
+ wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
36
+
37
+ wav_data = torch.tensor(wav_data).unsqueeze(0)
38
+
39
+ mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
40
+ hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
41
+ return mel.squeeze(0).transpose(0, 1)
42
+
43
+ @torch.inference_mode()
44
+ def extract_mel_from_wav_batch(self, wav_data):
45
+ """
46
+ params:
47
+ wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
48
+ return:
49
+ mel: [Batch, T, num_mels], torch.tensor
50
+ """
51
+
52
+ wav_data = torch.tensor(wav_data)
53
+
54
+ mel = get_melspec(wav=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
55
+ hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
56
+ return mel.transpose(1, 2)
57
+
58
+ def decode_mel(self, mel):
59
+ """
60
+ params:
61
+ mel: [T, num_mels], torch.tensor
62
+ return:
63
+ wav: [1, T], torch.tensor
64
+ """
65
+ mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
66
+ wav = self.vocoder(mel)
67
+ return wav.squeeze(0)
68
+
69
+ def decode_mel_batch(self, mel):
70
+ """
71
+ params:
72
+ mel: [B, T, num_mels], torch.tensor
73
+ return:
74
+ wav: [B, 1, T], torch.tensor
75
+ """
76
+ mel = mel.transpose(1, 2).to(self.device)
77
+ wav = self.vocoder(mel)
78
+ return wav
79
+
80
+ @classmethod
81
+ def from_pretrained(cls, model_config, ckpt_path, device):
82
+ with open(model_config) as f:
83
+ data = f.read()
84
+ json_config = json.loads(data)
85
+ h = AttrDict(json_config)
86
+ vocoder = BigVGAN(h, True)
87
+ state_dict_g = load_checkpoint(ckpt_path, "cpu")
88
+ vocoder.load_state_dict(state_dict_g["generator"])
89
+
90
+ logger.info(">>> Load vocoder from {}".format(ckpt_path))
91
+ return cls(vocoder, device, h)
92
+
93
+
94
+
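
A round-trip sketch for the wrapper (not part of this commit; the config and checkpoint paths follow `get_audio_detokenizer()` above and assume `download_pretrain.py` has been run). One thing to note: `extract_mel_from_wav` passes the waveform to `get_melspec` as `y=`, while `extract_mel_from_wav_batch` passes it as `wav=`; only one of these can match the signature in `vocoder/utils.py`, which is not shown in this diff.

```python
import torch

from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocoder = BigVGANWrapper.from_pretrained("resources/vocoder/config.json",
                                         "resources/vocoder/model.pt", device)

with torch.inference_mode():
    mel = vocoder.extract_mel_from_wav(wav_path="./zh_prompt0.wav")  # [T, num_mels]
    wav = vocoder.decode_mel(mel)                                    # [1, T * hop_size]
print(mel.shape, wav.shape)
```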
modules/audio_detokenizer/flow_matching/dit_block.py ADDED
@@ -0,0 +1,236 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func
10
+
11
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
12
+ # x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2
13
+ # the last shape is "self.hidden_dim / 2" because we convert to complex
14
+ assert x.ndim == 4
15
+ assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \
16
+ f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
17
+
18
+ # reshape freq cis to match and apply pointwise multiply
19
+ # new shape: bsz, seq_len, 1, self.head_hidden_dim / 2
20
+ shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]
21
+ return freqs_cis.view(*shape)
22
+
23
+
24
+ def apply_rotary_emb(
25
+ xq: torch.Tensor,
26
+ xk: torch.Tensor,
27
+ freqs_cis: torch.Tensor,
28
+ ):
29
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
30
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
31
+
32
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
33
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
34
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
35
+ return xq_out.type_as(xq), xk_out.type_as(xk)
36
+
37
+
38
+
39
+ class Attention(nn.Module):
40
+
41
+ def __init__(
42
+ self,
43
+ dim: int,
44
+ num_heads: int = 8,
45
+ qkv_bias: bool = False,
46
+ qk_norm: bool = False,
47
+ attn_drop: float = 0.,
48
+ proj_drop: float = 0.,
49
+ norm_layer: nn.Module = nn.LayerNorm,
50
+ flash_attention: bool = True
51
+ ) -> None:
52
+ super().__init__()
53
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+ self.fused_attn = flash_attention
58
+
59
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
60
+ self.qk_norm = qk_norm
61
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
62
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
63
+ self.attn_drop = nn.Dropout(attn_drop)
64
+ self.proj = nn.Linear(dim, dim)
65
+ self.proj_drop = nn.Dropout(proj_drop)
66
+
67
+ def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
68
+ B, N, C = x.shape
69
+
70
+ if self.fused_attn:
71
+ if nopadding:
72
+ qkv = self.qkv(x)
73
+ qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
74
+ q, k, v = qkv.split([self.num_heads] * 3, dim=1)
75
+ q, k = self.q_norm(q), self.k_norm(k)
76
+
77
+ q = q.view(B, N, self.num_heads, self.head_dim)
78
+ k = k.view(B, N, self.num_heads, self.head_dim)
79
+ v = v.view(B, N, self.num_heads, self.head_dim)
80
+
81
+ if rotary_pos_emb is not None:
82
+ q, k = apply_rotary_emb(q, k, rotary_pos_emb)
83
+
84
+ if incremental_state is not None:
85
+ if "prev_k" in incremental_state:
86
+ prev_k = incremental_state["prev_k"]
87
+ k = torch.cat([prev_k, k], dim=1)
88
+
89
+ if "cur_k" not in incremental_state:
90
+ incremental_state["cur_k"] = {}
91
+ incremental_state["cur_k"] = k
92
+
93
+ if "prev_v" in incremental_state:
94
+ prev_v = incremental_state["prev_v"]
95
+ v = torch.cat([prev_v, v], dim=1)
96
+
97
+ if "cur_v" not in incremental_state:
98
+ incremental_state["cur_v"] = {}
99
+ incremental_state["cur_v"] = v
100
+
101
+ q = q.view(B * N, self.num_heads, self.head_dim)
102
+ k = k.view(-1, self.num_heads, self.head_dim)
103
+ v = v.view(-1, self.num_heads, self.head_dim)
104
+
105
+ x = flash_attn_varlen_func(
106
+ q=q,
107
+ k=k,
108
+ v=v,
109
+ cu_seqlens_q=cu_seqlens,
110
+ cu_seqlens_k=cu_seqlens_k,
111
+ max_seqlen_q=max_seqlen,
112
+ max_seqlen_k=max_seqlen_k,
113
+ dropout_p=self.attn_drop.p if self.training else 0.,
114
+ )
115
+ else:
116
+
117
+ if incremental_state is not None:
118
+ raise NotImplementedError("The padded path is designed for batched inference; incremental (AR-chunk) decoding is not supported here.")
119
+
120
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
121
+ if self.qk_norm:
122
+ q, k, v = qkv.unbind(2)
123
+ q, k = self.q_norm(q), self.k_norm(k)
124
+ # re-bind
125
+ qkv = torch.stack((q, k, v), dim=2)
126
+
127
+ # pack qkv with seq_len
128
+ qkv_collect = []
129
+ for i in range(qkv.shape[0]):
130
+ qkv_collect.append(
131
+ qkv[i, :seq_len[i], :, :, :]
132
+ )
133
+
134
+ qkv = torch.cat(qkv_collect, dim=0)
135
+
136
+ x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)
137
+
138
+ # unpack and pad 0
139
+ x_collect = []
140
+ for i in range(B):
141
+ x_collect.append(
142
+ x[cu_seqlens[i]:cu_seqlens[i+1], :, :]
143
+ )
144
+ x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)
145
+
146
+ else:
147
+ q = q * self.scale
148
+ attn = q @ k.transpose(-2, -1)
149
+ attn = attn.softmax(dim=-1)
150
+ attn = self.attn_drop(attn)
151
+ x = attn @ v
152
+ x = x.transpose(1, 2)
153
+
154
+ x = x.reshape(B, N, C)
155
+ x = self.proj(x)
156
+ x = self.proj_drop(x)
157
+ return x
158
+
159
+
160
+ def modulate(x, shift, scale):
161
+ return x * (1 + scale) + shift
162
+
163
+
164
+ class FinalLayer(nn.Module):
165
+ """
166
+ The final layer of DiT.
167
+ """
168
+ def __init__(self, hidden_size, out_channels):
169
+ super().__init__()
170
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
171
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
172
+ self.adaLN_modulation = nn.Sequential(
173
+ nn.SiLU(),
174
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
175
+ )
176
+
177
+ def forward(self, x, c):
178
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
179
+ x = modulate(self.norm_final(x), shift, scale)
180
+ x = self.linear(x)
181
+ return x
182
+
183
+
184
+ class DiTBlock(nn.Module):
185
+ """
186
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
187
+ """
188
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
189
+ super().__init__()
190
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
191
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
192
+
193
+
194
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
195
+
196
+ if ffn_type == "vanilla_mlp":
197
+ from timm.models.vision_transformer import Mlp
198
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
199
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
200
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
201
+ else:
202
+ raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
203
+
204
+ self.adaLN_modulation = nn.Sequential(
205
+ nn.SiLU(),
206
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
207
+ )
208
+
209
+ def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
210
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
211
+
212
+ x_ = modulate(self.norm1(x), shift_msa, scale_msa)
213
+
214
+ if incremental_state is not None:
215
+ if "attn_kvcache" not in incremental_state:
216
+ incremental_state["attn_kvcache"] = {}
217
+ inc_attn = incremental_state["attn_kvcache"]
218
+ else:
219
+ inc_attn = None
220
+
221
+ x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)
222
+
223
+ if not nopadding:
224
+ x_ = x_ * mask[:, :, None]
225
+
226
+ x = x + gate_msa * x_
227
+
228
+ x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
229
+
230
+ x_ = self.mlp(x_)
231
+
232
+ if not nopadding:
233
+ x_ = x_ * mask[:, :, None]
234
+
235
+ x = x + gate_mlp * x_
236
+ return x
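
Two observations on `Attention.forward`: the non-flash `else` branch uses `q`, `k`, `v` that are only assigned inside the flash-attention path, so in practice the block assumes `flash_attention=True`; and rotary embeddings are applied to the complex view of each head's channel pairs. A toy shape check for the rotary helpers (not part of this commit; `precompute_freqs_cis` is defined in `flow_matching/model.py`, the next file in this diff):

```python
import torch

from modules.audio_detokenizer.flow_matching.dit_block import apply_rotary_emb
from modules.audio_detokenizer.flow_matching.model import precompute_freqs_cis

B, N, H, D = 2, 16, 8, 64                              # batch, sequence length, heads, head dim
q = torch.randn(B, N, H, D)
k = torch.randn(B, N, H, D)

freqs_cis = precompute_freqs_cis(D, 4096)[:N]          # [N, D // 2], complex64
freqs_cis = freqs_cis.unsqueeze(0).repeat(B, 1, 1)     # [B, N, D // 2], as reshape_for_broadcast expects
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```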
modules/audio_detokenizer/flow_matching/model.py ADDED
@@ -0,0 +1,295 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer
5
+
6
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
7
+ interpolation_factor: int = 1, max_seq_length: int = 4096):
8
+ print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
9
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
10
+
11
+ # RoPE type-A extension
12
+ # we choose to use interpolation rather than extrapolation for better position encoding
13
+ # for scale purposes, t should be a float tensor
14
+ t = torch.arange(end, device=freqs.device).float()
15
+ scale = 1.0 / float(interpolation_factor)
16
+ t *= scale
17
+
18
+ freqs = torch.outer(t, freqs).float() # type: ignore
19
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
20
+
21
+ # Sometimes we don't need this many RoPE embeddings, since seq_len can be much smaller than max_pos_emb
22
+ # e.g. a RoPE table for 1M positions with a 32k sequence would waste GPU memory
23
+ if max_seq_length < end:
24
+ freqs_cis = freqs_cis[:max_seq_length,].clone()
25
+ return freqs_cis
26
+
27
+
28
+ class TimestepEmbedder(nn.Module):
29
+ """
30
+ Embeds scalar timesteps into vector representations.
31
+ """
32
+ def __init__(self, hidden_size, frequency_embedding_size=256):
33
+ super().__init__()
34
+ self.mlp = nn.Sequential(
35
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
36
+ nn.SiLU(),
37
+ nn.Linear(hidden_size, hidden_size, bias=True),
38
+ )
39
+ self.frequency_embedding_size = frequency_embedding_size
40
+
41
+ @staticmethod
42
+ def timestep_embedding(t, dim, max_period=10000):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param t: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an (N, D) Tensor of positional embeddings.
50
+ """
51
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
55
+ ).float().to(device=t.device)
56
+ args = t[:, None].float() * freqs[None]
57
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
58
+ if dim % 2:
59
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
60
+ return embedding
61
+
62
+ def forward(self, t):
63
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
64
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
65
+ return t_emb
66
+
67
+
68
+ class SinusoidalPositionalEmbedding(nn.Module):
69
+ """This module produces sinusoidal positional embeddings of any length.
70
+
71
+ Padding symbols are ignored.
72
+ """
73
+
74
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
75
+ super().__init__()
76
+ self.embedding_dim = embedding_dim
77
+ self.padding_idx = padding_idx
78
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
79
+ init_size,
80
+ embedding_dim,
81
+ padding_idx,
82
+ )
83
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
84
+
85
+ @staticmethod
86
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
87
+ """Build sinusoidal embeddings.
88
+
89
+ This matches the implementation in tensor2tensor, but differs slightly
90
+ from the description in Section 3.5 of "Attention Is All You Need".
91
+ """
92
+ half_dim = embedding_dim // 2 # d/2
93
+ emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2)
94
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, )
95
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2)
96
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) # shape: (num_embeddings, d)
97
+ if embedding_dim % 2 == 1:
98
+ # zero pad
99
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
100
+ if padding_idx is not None:
101
+ emb[padding_idx, :] = 0
102
+ return emb
103
+
104
+ def forward(self, input, incremental_state=None, timestep=None, **kwargs):
105
+ """Input is expected to be of size [bsz x seqlen]."""
106
+ bsz, seq_len = input.shape[:2]
107
+ max_pos = self.padding_idx + 1 + seq_len
108
+ if self.weights is None or max_pos > self.weights.size(0):
109
+ # recompute/expand embeddings if needed
110
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
111
+ max_pos,
112
+ self.embedding_dim,
113
+ self.padding_idx,
114
+ )
115
+ self.weights = self.weights.to(self._float_tensor)
116
+
117
+ if incremental_state is not None:
118
+ # positions is the same for every token when decoding a single step
119
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
120
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
121
+
122
+ positions = self.make_positions(input, self.padding_idx)
123
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() # (B, T, dim)
124
+
125
+ def max_positions(self):
126
+ """Maximum number of supported positions."""
127
+ return int(1e5) # an arbitrary large number
128
+
129
+ def make_positions(self, tensor, padding_idx):
130
+ """Replace non-padding symbols with their position numbers.
131
+
132
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
133
+ """
134
+ # The series of casts and type-conversions here are carefully
135
+ # balanced to both work with ONNX export and XLA. In particular XLA
136
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
137
+ # how to handle the dtype kwarg in cumsum.
138
+ mask = tensor.ne(padding_idx).int()
139
+ return (
140
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
141
+ ).long() + padding_idx
142
+
143
+
144
+ class DiTPrefix(nn.Module):
145
+ """
146
+ Diffusion model with a Transformer backbone.
147
+ """
148
+ def __init__(
149
+ self,
150
+ input_size,
151
+ output_size,
152
+ semantic_vocab_size,
153
+ hidden_size=1024,
154
+ depth=12,
155
+ num_heads=4,
156
+ # mlp related
157
+ mlp_ratio=4.0,
158
+ ffn_type="conv1d_conv1d",
159
+ ffn_gated_glu=True,
160
+ ffn_act_layer="gelu",
161
+ ffn_conv_kernel_size=5,
162
+
163
+ # rope
164
+ use_rope=False,
165
+ rope_params={
166
+ "max_position_embeddings": 4096,
167
+ "rope_base": 10000.0,
168
+ "rope_interpolation_factor": 1.0,
169
+ },
170
+
171
+
172
+ position_embedding_type="sincos",
173
+ max_seq_len=4096,
174
+ prompt_cfg_dropout=0.0
175
+ ):
176
+ super().__init__()
177
+ self.num_heads = num_heads
178
+
179
+ self.prompt_cfg_dropout = prompt_cfg_dropout
180
+
181
+ self.t_embedder = TimestepEmbedder(hidden_size)
182
+
183
+ self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
184
+
185
+ self.input_linear = nn.Linear(input_size, hidden_size)
186
+
187
+ # position embedding
188
+ if position_embedding_type == "learnable":
189
+ self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)
190
+ elif position_embedding_type == "sincos":
191
+ self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)
192
+ elif position_embedding_type == "skip":
193
+ self.position_embedding = None
194
+ else:
195
+ raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))
196
+
197
+ self.use_rope = use_rope
198
+
199
+ if self.use_rope:
200
+
201
+ assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
202
+ rope_dim = hidden_size // num_heads
203
+
204
+ self.rotary_pos_emb = precompute_freqs_cis(
205
+ rope_dim, rope_params["max_position_embeddings"],
206
+ theta=rope_params["rope_base"],
207
+ interpolation_factor=rope_params["rope_interpolation_factor"],
208
+ max_seq_length=max_seq_len
209
+ )
210
+
211
+ self.blocks = nn.ModuleList([
212
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
213
+ ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)
214
+ ])
215
+ self.final_layer = FinalLayer(hidden_size, output_size)
216
+ self.initialize_weights()
217
+
218
+ def initialize_weights(self):
219
+ # Initialize transformer layers:
220
+ def _basic_init(module):
221
+ if isinstance(module, nn.Linear):
222
+ torch.nn.init.xavier_uniform_(module.weight)
223
+ if module.bias is not None:
224
+ nn.init.constant_(module.bias, 0)
225
+ self.apply(_basic_init)
226
+
227
+
228
+ # Initialize timestep embedding MLP:
229
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
230
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
231
+
232
+ # Zero-out adaLN modulation layers in DiT blocks:
233
+ for block in self.blocks:
234
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
235
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
236
+
237
+ # Zero-out output layers:
238
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
239
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
240
+ nn.init.constant_(self.final_layer.linear.weight, 0)
241
+ nn.init.constant_(self.final_layer.linear.bias, 0)
242
+
243
+ def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
244
+ """
245
+ Forward pass of DiT.
246
+ x: (N, T, C) tensor of inputs (latent representations of speech)
247
+ position_ids: (N, T) tensor of positional indices
248
+ t: (N,) tensor of diffusion timesteps
249
+ condition: (N, T) tensor of semantic tokens
250
+ seq_len: (N,) tensor of sequence lengths
251
+ """
252
+
253
+ condition = self.semantic_token_embedding(condition) # (N, T, D)
254
+
255
+ x = self.input_linear(x)
256
+
257
+ if self.position_embedding is not None:
258
+ position_emb = self.position_embedding(position_ids)
259
+ x = x + position_emb
260
+
261
+ # ROPE
262
+ if self.use_rope:
263
+ bsz, seqlen = position_ids.shape
264
+ if self.rotary_pos_emb.device != position_ids.device:
265
+ self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
266
+ rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]),
267
+ dtype=self.rotary_pos_emb.dtype,
268
+ device=self.rotary_pos_emb.device)
269
+ for b in range(bsz):
270
+ cur_rope = rotary_pos_emb[b]
271
+ cur_position_ids = position_ids[b]
272
+ cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
273
+ else:
274
+ rotary_pos_emb = None
275
+
276
+ t = self.t_embedder(t) # (N, D)
277
+ c = t.unsqueeze(1) + condition # (N, T, D)
278
+
279
+
280
+ for block_idx, block in enumerate(self.blocks):
281
+ # x = block(x, c, attn_mask) # (N, T, D)
282
+ # XXX mask could be None because we always use full mask
283
+
284
+ if incremental_state is not None:
285
+ if block_idx not in incremental_state:
286
+ incremental_state[block_idx] = {}
287
+ incr = incremental_state[block_idx]
288
+ else:
289
+ incr = None
290
+
291
+ x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)
292
+
293
+ x = self.final_layer(x, c) # (N, T, C)
294
+ return x
295
+
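In the forward pass above, the conditioning fed to every DiT block is simply the timestep embedding broadcast over time plus the semantic-token embedding. A small shape sketch with made-up sizes:

import torch
import torch.nn as nn

N, T, D, vocab = 2, 150, 1024, 8193
t_emb = torch.randn(N, D)                                  # TimestepEmbedder output, (N, D)
token_emb = nn.Embedding(vocab, D)
condition = token_emb(torch.randint(0, vocab, (N, T)))     # semantic tokens -> (N, T, D)
c = t_emb.unsqueeze(1) + condition                         # (N, T, D), shared by all blocks
print(c.shape)                                             # torch.Size([2, 150, 1024])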
modules/audio_detokenizer/flow_matching/ode_wrapper.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import lru_cache
4
+ import copy
5
+
6
+
7
+ @lru_cache(maxsize=1)
8
+ def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
9
+ return torch.zeros(numel, device=device, dtype=dtype)
10
+
11
+ class StreamingODEWrapperForPrefix(nn.Module):
12
+ def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
13
+ super(StreamingODEWrapperForPrefix, self).__init__()
14
+ self.net = net
15
+ self.x_mask = x_mask
16
+ self.x_cond = x_cond
17
+
18
+ assert use_cfg == False, "cfg is not supported in streaming detokenizer"
19
+
20
+ self.use_cfg = use_cfg
21
+ self.use_cfg_rescale = use_cfg_rescale
22
+ self.cfg_init = cfg_init
23
+ self.cfg_scale = cfg_scale
24
+ self.cfg_token_id = cfg_token_id
25
+ self.cfg_schedule = cfg_schedule
26
+ self.position_ids = None
27
+ self.seq_len = None
28
+
29
+ self.incremental_state = {}
30
+ self.kv_cache_tokens = 0
31
+ self.cu_seqlens = None
32
+ self.cu_maxlen = None
33
+
34
+ self.cu_seqlens_k = None
35
+ self.cu_maxlen_k = None
36
+ self.previous_seqlen = None
37
+
38
+ def clear_all_states(self):
39
+ self.incremental_state = {}
40
+ self.kv_cache_tokens = 0
41
+ self.cu_seqlens = None
42
+ self.cu_maxlen = None
43
+
44
+ self.cu_seqlens_k = None
45
+ self.cu_maxlen_k = None
46
+ self.previous_seqlen = None
47
+
48
+ def state_dict(self):
49
+ return {
50
+ "incremental_state": copy.deepcopy(self.incremental_state),
51
+ "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
52
+ "cu_seqlens": copy.deepcopy(self.cu_seqlens),
53
+ "cu_maxlen": copy.deepcopy(self.cu_maxlen),
54
+ "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
55
+ "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
56
+ "previous_seqlen": copy.deepcopy(self.previous_seqlen)
57
+ }
58
+
59
+ def load_state_dict(self, state_dict):
60
+ self.incremental_state = state_dict["incremental_state"]
61
+ self.kv_cache_tokens = state_dict["kv_cache_tokens"]
62
+ self.cu_seqlens = state_dict["cu_seqlens"]
63
+ self.cu_maxlen = state_dict["cu_maxlen"]
64
+ self.cu_seqlens_k = state_dict["cu_seqlens_k"]
65
+ self.cu_maxlen_k = state_dict["cu_maxlen_k"]
66
+ self.previous_seqlen = state_dict["previous_seqlen"]
67
+
68
+ def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
69
+ if not self.use_cfg:
70
+ self.x_mask = x_mask
71
+ self.x_cond = x_cond
72
+ else:
73
+ self.x_cond = torch.cat((x_cond, x_cond), dim=0)
74
+ self.x_mask = torch.cat((x_mask, x_mask), dim=0)
75
+
76
+ position_ids_cur = [i for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)]
77
+ position_ids = torch.tensor([position_ids_cur])
78
+
79
+
80
+ if not self.use_cfg:
81
+ self.position_ids = position_ids.to(self.x_cond.device).long()
82
+ self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
83
+ else:
84
+ self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
85
+ self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
86
+
87
+ cu_seqlens = torch.cumsum(self.seq_len, dim=0)
88
+ self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
89
+ self.cu_maxlen = self.seq_len.cpu().max()
90
+
91
+ if self.cu_seqlens_k is None:
92
+ self.cu_seqlens_k = self.cu_seqlens
93
+ self.cu_maxlen_k = self.cu_maxlen
94
+ previous_seqlen = self.seq_len
95
+ else:
96
+ previous_seqlen_old = cache["previous_seqlen"]
97
+ previous_seqlen = previous_seqlen_old + self.seq_len
98
+ # calculate cu_seqlens_k
99
+ cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
100
+ self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
101
+ self.cu_maxlen_k = previous_seqlen.cpu().max()
102
+ self.previous_seqlen = previous_seqlen
103
+ ret_cache = {
104
+ "previous_seqlen": previous_seqlen
105
+ }
106
+ return ret_cache
107
+
108
+ def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache={}):
109
+
110
+ assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
111
+
112
+ for layer_idx, layer_cache in self.incremental_state.items():
113
+ # update attention kv cache
114
+ layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
115
+ layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
116
+
117
+ self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
118
+
119
+ if self.kv_cache_tokens > max_kv_cache_tokens:
120
+ # drop old tokens from reserve kv cache tokens to max_kv_cache_tokens
121
+ reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
122
+
123
+ if reserve_kv_cache_tokens == 0:
124
+ layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
125
+ layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
126
+ elif reserve_tokens_excludeprompt == 0:
127
+ layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
128
+ layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
129
+ else:
130
+ layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
131
+ layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
132
+ layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
133
+ ], dim=1)
134
+
135
+ layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
136
+ layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
137
+ layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
138
+ ], dim=1)
139
+
140
+
141
+ bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
142
+ self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
143
+ condition_cache["previous_seqlen"] = self.previous_seqlen
144
+ self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
145
+
146
+ # clear current cache
147
+ layer_cache["attn_kvcache"].pop("cur_k")
148
+ layer_cache["attn_kvcache"].pop("cur_v")
149
+
150
+
151
+ def forward(self, t, x, args=None):
152
+ # t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long()
153
+ t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
154
+
155
+ if self.use_cfg:
156
+ raise NotImplementedError("cfg is not supported in streaming detokenizer.")
157
+ else:
158
+ pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids,
159
+ cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
160
+ cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
161
+ incremental_state=self.incremental_state, nopadding=True,
162
+ mask=None, seq_len=None
163
+ )
164
+ return pred_noise
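set_conditions above mainly maintains the cumulative-sequence-length tensors (cu_seqlens for the current queries, cu_seqlens_k for queries plus cached keys/values) that varlen attention kernels expect. A standalone sketch of that bookkeeping with made-up lengths:

import torch

seq_len = torch.tensor([150])                                        # tokens in the current chunk, per batch entry
zero = torch.zeros(1, dtype=seq_len.dtype)
cu_seqlens = torch.cat([zero, seq_len.cumsum(0)]).int()              # [0, 150] -> query offsets
prev = torch.tensor([300])                                           # tokens already held in the kv cache
cu_seqlens_k = torch.cat([zero, (prev + seq_len).cumsum(0)]).int()   # [0, 450] -> key/value offsets
print(cu_seqlens.tolist(), cu_seqlens_k.tolist())                    # [0, 150] [0, 450]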
modules/audio_detokenizer/flow_matching/scheduler.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ from abc import abstractmethod, ABC
3
+ try:
4
+ from torchdyn.core import NeuralODE
5
+ NEURALODE_INSTALLED = True
6
+ except ImportError:
7
+ NEURALODE_INSTALLED = False
8
+
9
+ class SchedulerBase(ABC):
10
+ def __init__(self) -> None:
11
+ pass
12
+
13
+ @abstractmethod
14
+ def set_timesteps(self):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def step(self):
19
+ pass
20
+
21
+ @abstractmethod
22
+ def add_noise(self):
23
+ pass
24
+
25
+
26
+ class StreamingFlowMatchingScheduler(SchedulerBase):
27
+ def __init__(self, timesteps=1000, sigma_min=1e-4,
28
+ ) -> None:
29
+ super().__init__()
30
+
31
+ self.sigma_min = sigma_min
32
+ self.timesteps = timesteps
33
+ self.t_min = 0
34
+ self.t_max = 1 - self.sigma_min
35
+
36
+ self.neural_ode = None
37
+
38
+
39
+ def set_timesteps(self, timesteps=15):
40
+ self.timesteps = timesteps
41
+
42
+ def step(self, xt, predicted_v):
43
+
44
+ h = (self.t_max - self.t_min) / self.timesteps
45
+ h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
46
+
47
+ xt = xt + h * predicted_v
48
+ return xt
49
+
50
+ def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
51
+ h = (self.t_max - self.t_min) / self.timesteps
52
+ h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
53
+
54
+ if verbose:
55
+ gt_v = x0 - xt
56
+
57
+ for t in time_steps:
58
+ predicted_v = ode_wrapper(t, xt)
59
+ if verbose:
60
+ dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
61
+ print("Time: {}, Distance: {}".format(t, dist))
62
+ xt = xt + h * predicted_v
63
+ return xt
64
+
65
+ def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
66
+ if not NEURALODE_INSTALLED:
67
+ raise ImportError("NeuralODE is not installed, please install it first.")
68
+
69
+ if self.neural_ode is None:
70
+ self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)
71
+
72
+ eval_points, traj = self.neural_ode(xt, time_steps)
73
+ return traj[-1]
74
+
75
+
76
+ def add_noise(self, original_samples: torch.FloatTensor,
77
+ noise: torch.FloatTensor,
78
+ timesteps: torch.IntTensor,):
79
+ ut = original_samples - (1 - self.sigma_min) * noise # unrelated to the gradient of ut
80
+ t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
81
+ x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
82
+ return x_noisy, ut
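For intuition, the conditional flow used here is x_t = t·x_1 + (1 − (1 − σ_min)·t)·x_0 with constant target velocity u_t = x_1 − (1 − σ_min)·x_0, so integrating the true velocity up to t = 1 recovers the data up to σ_min-scaled noise. A small numeric check (standalone, not using the class):

import torch

sigma_min = 1e-4
x1 = torch.randn(4, 80)                 # data (mel frames)
x0 = torch.randn(4, 80)                 # noise
ut = x1 - (1 - sigma_min) * x0          # target velocity, independent of t
t = 0.3
xt = t * x1 + (1 - (1 - sigma_min) * t) * x0
x1_hat = xt + (1 - t) * ut              # one exact Euler step to t = 1
print(torch.allclose(x1_hat, x1 + sigma_min * x0, atol=1e-5))   # True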
modules/audio_detokenizer/semantic_fm_prefix_streaming.py ADDED
@@ -0,0 +1,273 @@
1
+ import yaml
2
+ import logging
3
+ import time
4
+
5
+ import os
6
+ import torch
7
+
8
+ from modules.audio_detokenizer.flow_matching.ode_wrapper import StreamingODEWrapperForPrefix
9
+ from modules.audio_detokenizer.flow_matching.model import DiTPrefix
10
+ from modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class StreamingSemanticFMWrapper:
17
+ def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, max_prompt_chunk=2,
18
+ use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear", cfg_token_id=0,
19
+ normalize_mel=False, mel_mean=None, mel_std=None, device: torch.device = torch.device("cpu")) -> None:
20
+
21
+ self.dtype = torch.bfloat16
22
+ self.speech_model = speech_model.to(device).to(self.dtype)
23
+ self.speech_model = self.speech_model.eval()
24
+ self.device = device
25
+ self.normalize_mel = normalize_mel
26
+ self.mel_mean = mel_mean
27
+ self.mel_std = mel_std
28
+
29
+ self.use_cfg = use_cfg
30
+ self.use_cfg_rescale = use_cfg_rescale
31
+ self.cfg_init = cfg_init
32
+ self.cfg_scale = cfg_scale
33
+ self.cfg_schedule = cfg_schedule
34
+
35
+ self.incremental_state = {}
36
+ self.condition_cache = {"previous_seqlen": 0}
37
+
38
+ logger.info(f">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}")
39
+
40
+ self.scheduler = StreamingFlowMatchingScheduler()
41
+ self.ode_wrapper = StreamingODEWrapperForPrefix(net=self.speech_model, x_mask=None, x_cond=None,
42
+ use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_token_id)
43
+
44
+ self.max_kv_cache_tokens = max_kv_cache_tokens
45
+ self.max_prompt_chunk = max_prompt_chunk
46
+ self.reserve_kv_cache_tokens = 0
47
+
48
+ @torch.inference_mode()
49
+ def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id,
50
+ cache = None, look_ahead_tokens=0,
51
+ ode_steps=15, verbose=False, ode_solver="neural_ode_euler"):
52
+ """
53
+ semantic_tokens: [T_1], torch.LongTensor
54
+ xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
55
+ ode_steps: int, number of ode steps, default 15
56
+ verbose: bool, default False
57
+ ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
58
+ """
59
+ bs = 1
60
+
61
+ self.scheduler.set_timesteps(ode_steps)
62
+
63
+ semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
64
+ xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)
65
+
66
+ t_span = torch.linspace(0, 1, self.scheduler.timesteps)
67
+
68
+ x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()
69
+
70
+ cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
71
+
72
+ if verbose:
73
+ t_start = time.time()
74
+ if ode_solver == "neural_ode_euler":
75
+ x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
76
+ elif ode_solver == "naive_euler":
77
+ x_t = self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
78
+ else:
79
+ raise NotImplementedError("ode_solver should be in ('neural_ode_euler', 'naive_euler')")
80
+
81
+ if look_ahead_tokens > 0:
82
+ semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]
83
+ cache["semantic_token"] = semantic_tokens_left
84
+ x_t_ret = x_t[:, :-look_ahead_tokens, :]
85
+ else:
86
+ x_t_ret = x_t
87
+
88
+ if look_ahead_tokens > 0:
89
+ x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()
90
+ self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)
91
+ self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)
92
+ else:
93
+ self.condition_cache = cache_ret
94
+
95
+ if verbose:
96
+ t_end = time.time()
97
+ logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}")
98
+
99
+ if self.normalize_mel:
100
+ x_t_ret = x_t_ret * self.mel_std + self.mel_mean
101
+ return x_t_ret.squeeze(0)
102
+
103
+
104
+ @torch.inference_mode()
105
+ def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver="neural_ode_euler"):
106
+ """
107
+ semantic_tokens: [T_1], torch.LongTensor
108
+ chunk_size: int, number of semantic tokens decoded per streaming chunk, default 150
109
+ (the prompt mel and prompt semantic tokens are consumed beforehand via prefill(), not passed here)
110
+ ode_steps: int, number of ode steps, default 15
111
+ verbose: bool, default False
112
+ ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
113
+ """
114
+ assert semantic_tokens.dim() == 1
115
+
116
+ x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)
117
+
118
+ seq_len = semantic_tokens.shape[0]
119
+
120
+ num_chunks = seq_len // chunk_size
121
+ if seq_len % chunk_size != 0:
122
+ num_chunks += 1
123
+
124
+ x_pred_collect = []
125
+
126
+ if verbose:
127
+ t_start = time.time()
128
+
129
+ for chunk_id in range(num_chunks):
130
+ start = chunk_id * chunk_size
131
+ end = min(start + chunk_size, seq_len)
132
+ semantic_tokens_chunk = semantic_tokens[start:end]
133
+ x_t_chunk = x_t[start:end, :]
134
+
135
+ x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id,
136
+ ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)
137
+ self.start_position_id += end - start
138
+ self.update_incremental_state()
139
+
140
+ x_pred_collect.append(x_pred)
141
+
142
+ if verbose:
143
+ t_end = time.time()
144
+ logger.info(f"[ODE] Time cost: {t_end - t_start}")
145
+
146
+ x_pred = torch.cat(x_pred_collect, dim=0)
147
+
148
+ return x_pred
149
+
150
+ def clear_all_states(self):
151
+ self.start_position_id = 0
152
+ self.condition_cache = {"previous_seqlen": 0}
153
+ self.ode_wrapper.clear_all_states()
154
+
155
+ def state_dict(self):
156
+ return {
157
+ "start_position_id": self.start_position_id,
158
+ "ode_wrapper": self.ode_wrapper.state_dict(),
159
+ "condition_cache": self.condition_cache
160
+ }
161
+
162
+ def load_state_dict(self, state_dict):
163
+ if state_dict is not None:
164
+ self.start_position_id = state_dict["start_position_id"]
165
+ self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"])
166
+ self.condition_cache = state_dict["condition_cache"]
167
+
168
+ def update_incremental_state(self):
169
+ self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)
170
+
171
+ @torch.inference_mode()
172
+ def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
173
+ """
174
+ mel: [T, 80], torch.Tensor
175
+ semantic_token: [T], torch.LongTensor
176
+ chunk_size: int, default 150
177
+ """
178
+ assert mel.dim() == 2
179
+ assert semantic_token.dim() == 1
180
+ assert semantic_token.shape[0] == mel.shape[0], "Semantic token and mel shape mismatch"
181
+ seq_len = mel.shape[0]
182
+ num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)
183
+ start_pos = seq_len - num_chunks * chunk_size
184
+
185
+ res_mel = mel[:start_pos, :]
186
+ res_semantic_token = semantic_token[:start_pos]
187
+ self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)
188
+ self.start_position_id += start_pos
189
+ self.update_incremental_state()
190
+ self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
191
+
192
+ if verbose:
193
+ logger.info("Prefilling prompt with {} chunks".format(num_chunks))
194
+ start_time = time.time()
195
+
196
+ for chunk_id in range(num_chunks):
197
+ start = start_pos + chunk_id * chunk_size
198
+ end = start + chunk_size
199
+ mel_chunk = mel[start:end, :]
200
+ semantic_token_chunk = semantic_token[start:end]
201
+
202
+ self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)
203
+ self.start_position_id += end - start
204
+
205
+ self.update_incremental_state()
206
+ self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
207
+
208
+
209
+ if verbose:
210
+ logger.info("Prefilling done in {:.2f} seconds".format(time.time() - start_time))
211
+
212
+ def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):
213
+ """
214
+ mel_chunk: [T, 80], torch.Tensor, T is the chunk size
215
+ semantic_tokens_chunk: [T], torch.LongTensor
216
+ start_position_id: int, default 0
217
+ """
218
+ bs = 1
219
+
220
+ semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
221
+ mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)
222
+
223
+ if self.normalize_mel:
224
+ mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std
225
+
226
+ x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()
227
+
228
+ self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
229
+
230
+ x_t = torch.Tensor([0.999]).to(self.device)
231
+
232
+ self.ode_wrapper(x_t, mel_chunk)
233
+
234
+
235
+ @classmethod
236
+ def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
237
+
238
+ # open yaml file
239
+ with open(model_config, 'r') as f:
240
+ config = yaml.safe_load(f)
241
+ model_config = config["model"]["dit"]
242
+ dit = DiTPrefix(
243
+ input_size=model_config["input_size"],
244
+ semantic_vocab_size=model_config["semantic_vocab_size"] + 1,
245
+ hidden_size=model_config["hidden_size"],
246
+ depth=model_config["depth"],
247
+ num_heads=model_config["num_heads"],
248
+ mlp_ratio=model_config["mlp_ratio"],
249
+ ffn_type=model_config.get("ffn_type", "conv1d_conv1d"),
250
+ ffn_gated_glu=model_config.get("ffn_gated_glu", True),
251
+ ffn_act_layer=model_config.get("ffn_act_layer", "gelu"),
252
+ ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5),
253
+
254
+ use_rope=model_config.get("use_rope", False),
255
+ rope_params=model_config.get("rope_params", { "max_position_embeddings": 4096,"rope_base": 10000,"rope_interpolation_factor": 1 }),
256
+
257
+ position_embedding_type=model_config["position_embedding_type"],
258
+ max_seq_len=model_config["max_seq_len"],
259
+ output_size=model_config["input_size"],
260
+ prompt_cfg_dropout=0
261
+ )
262
+ cfg_semantic_token_id = model_config["semantic_vocab_size"]
263
+
264
+ # load state_dict
265
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"]
266
+ speech_model_params = {k.replace("speech_model.", ""): v for k, v in state_dict.items() if "speech_model" in k}
267
+ dit.load_state_dict(speech_model_params, strict=True)
268
+ logger.info(f">>> Loaded checkpoint from {ckpt_path}")
269
+
270
+ return cls(speech_model=dit, device=device, normalize_mel=config["normalize_mel"], mel_mean=config["mel_mean"], mel_std=config["mel_std"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
271
+ use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)
272
+
273
+
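Putting this wrapper together end to end: load it with from_pretrained, reset the streaming state, prefill the prompt mel and its semantic tokens to warm the kv cache, then decode the target semantic tokens chunk by chunk via infer_mel. A hedged usage sketch; the config/checkpoint paths and tensor contents below are placeholders, not the released file names:

import torch
from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper

fm = StreamingSemanticFMWrapper.from_pretrained(
    "detokenizer_config.yaml", "detokenizer.pt", device=torch.device("cuda"))   # placeholder paths

prompt_mel = torch.randn(300, 80)                    # placeholder prompt mel, [T_p, 80]
prompt_tokens = torch.randint(0, 8192, (300,))       # placeholder prompt semantic tokens, [T_p]
target_tokens = torch.randint(0, 8192, (450,))       # placeholder target semantic tokens, [T]

fm.clear_all_states()                                # initializes start_position_id and caches
fm.prefill(prompt_mel, prompt_tokens)                # warms the prefix kv cache
mel = fm.infer_mel(target_tokens, ode_steps=15, chunk_size=150)   # -> [T, 80] mel for the vocoder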
modules/audio_detokenizer/vocoder/activations.py ADDED
@@ -0,0 +1,123 @@
1
+ import torch
2
+ from torch import nn, sin, pow
3
+ from torch.nn import Parameter
4
+
5
+
6
+ class Snake(nn.Module):
7
+ """
8
+ Implementation of a sine-based periodic activation function
9
+ Shape:
10
+ - Input: (B, C, T)
11
+ - Output: (B, C, T), same shape as the input
12
+ Parameters:
13
+ - alpha - trainable parameter
14
+ References:
15
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
16
+ https://arxiv.org/abs/2006.08195
17
+ Examples:
18
+ >>> a1 = Snake(256)
19
+ >>> x = torch.randn(256)
20
+ >>> x = a1(x)
21
+ """
22
+
23
+ def __init__(
24
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
25
+ ):
26
+ """
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ """
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # Initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # Linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ """
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ """
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ """
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = SnakeBeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ """
79
+
80
+ def __init__(
81
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
82
+ ):
83
+ """
84
+ Initialization.
85
+ INPUT:
86
+ - in_features: shape of the input
87
+ - alpha - trainable parameter that controls frequency
88
+ - beta - trainable parameter that controls magnitude
89
+ alpha is initialized to 1 by default, higher values = higher-frequency.
90
+ beta is initialized to 1 by default, higher values = higher-magnitude.
91
+ alpha will be trained along with the rest of your model.
92
+ """
93
+ super(SnakeBeta, self).__init__()
94
+ self.in_features = in_features
95
+
96
+ # Initialize alpha
97
+ self.alpha_logscale = alpha_logscale
98
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
99
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
100
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
101
+ else: # Linear scale alphas initialized to ones
102
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
103
+ self.beta = Parameter(torch.ones(in_features) * alpha)
104
+
105
+ self.alpha.requires_grad = alpha_trainable
106
+ self.beta.requires_grad = alpha_trainable
107
+
108
+ self.no_div_by_zero = 0.000000001
109
+
110
+ def forward(self, x):
111
+ """
112
+ Forward pass of the function.
113
+ Applies the function to the input elementwise.
114
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
115
+ """
116
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
117
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
118
+ if self.alpha_logscale:
119
+ alpha = torch.exp(alpha)
120
+ beta = torch.exp(beta)
121
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
122
+
123
+ return x
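A quick shape check for the activations above, assuming the module is importable from this repo under the path used elsewhere in the commit:

import torch
from modules.audio_detokenizer.vocoder.activations import SnakeBeta  # added in this commit

act = SnakeBeta(in_features=256, alpha_logscale=True)
x = torch.randn(2, 256, 100)        # (B, C, T)
print(act(x).shape)                 # torch.Size([2, 256, 100]); applied elementwise, shape preserved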
modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py ADDED
File without changes
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py ADDED
File without changes
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from ..torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
+
14
+ class FusedAntiAliasActivation(torch.autograd.Function):
15
+ """
16
+ Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17
+ The hyperparameters are hard-coded in the kernel to maximize speed.
18
+ NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
19
+ """
20
+
21
+ @staticmethod
22
+ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23
+ activation_results = anti_alias_activation_cuda.forward(
24
+ inputs, up_ftr, down_ftr, alpha, beta
25
+ )
26
+
27
+ return activation_results
28
+
29
+ @staticmethod
30
+ def backward(ctx, output_grads):
31
+ raise NotImplementedError
32
+ return output_grads, None, None
33
+
34
+
35
+ class Activation1d(nn.Module):
36
+ def __init__(
37
+ self,
38
+ activation,
39
+ up_ratio: int = 2,
40
+ down_ratio: int = 2,
41
+ up_kernel_size: int = 12,
42
+ down_kernel_size: int = 12,
43
+ fused: bool = True,
44
+ ):
45
+ super().__init__()
46
+ self.up_ratio = up_ratio
47
+ self.down_ratio = down_ratio
48
+ self.act = activation
49
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
50
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
51
+
52
+ self.fused = fused # Whether to use fused CUDA kernel or not
53
+
54
+ def forward(self, x):
55
+ if not self.fused:
56
+ x = self.upsample(x)
57
+ x = self.act(x)
58
+ x = self.downsample(x)
59
+ return x
60
+ else:
61
+ if self.act.__class__.__name__ == "Snake":
62
+ beta = self.act.alpha.data # Snake uses same params for alpha and beta
63
+ else:
64
+ beta = (
65
+ self.act.beta.data
66
+ ) # Snakebeta uses different params for alpha and beta
67
+ alpha = self.act.alpha.data
68
+ if (
69
+ not self.act.alpha_logscale
70
+ ): # Exp baked into cuda kernel, cancel it out with a log
71
+ alpha = torch.log(alpha)
72
+ beta = torch.log(beta)
73
+
74
+ x = FusedAntiAliasActivation.apply(
75
+ x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
76
+ )
77
+ return x
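The non-fused path is simply upsample → activation → downsample with the same filter size 12 and replication padding that the fused CUDA kernel hard-codes, so both paths should agree numerically up to kernel precision. A sketch of the eager path, assuming the pure-PyTorch Activation1d from the torch/ subpackage takes the same constructor arguments (minus the fused flag):

import torch
from modules.audio_detokenizer.vocoder.activations import SnakeBeta
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d  # eager version

act = Activation1d(SnakeBeta(64, alpha_logscale=True))
y = act(torch.randn(1, 64, 200))    # upsample x2 -> SnakeBeta -> downsample x2
print(y.shape)                      # torch.Size([1, 64, 200]); length preserved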
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <torch/extension.h>
18
+
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activations. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximize GPU utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /* This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #ifndef TORCH_CHECK
22
+ #define TORCH_CHECK AT_CHECK
23
+ #endif
24
+
25
+ #ifdef VERSION_GE_1_3
26
+ #define DATA_PTR data_ptr
27
+ #else
28
+ #define DATA_PTR data
29
+ #endif
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ """
11
+ Setting this param to a list can generate different compilation commands (with a different order of architectures) and trigger recompilation of the fused kernels.
12
+ Set it to an empty string to avoid recompilation, and assign arch flags explicitly in extra_cuda_cflags below.
13
+ """
14
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
+
16
+
17
+ def load():
18
+ # Check if cuda 11 is installed for compute capability 8.0
19
+ cc_flag = []
20
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
21
+ if int(bare_metal_major) >= 11:
22
+ cc_flag.append("-gencode")
23
+ cc_flag.append("arch=compute_80,code=sm_80")
24
+
25
+ # Build path
26
+ srcpath = pathlib.Path(__file__).parent.absolute()
27
+ buildpath = srcpath / "build"
28
+ _create_build_dir(buildpath)
29
+
30
+ # Helper function to build the kernels.
31
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
32
+ return cpp_extension.load(
33
+ name=name,
34
+ sources=sources,
35
+ build_directory=buildpath,
36
+ extra_cflags=[
37
+ "-O3",
38
+ ],
39
+ extra_cuda_cflags=[
40
+ "-O3",
41
+ "-gencode",
42
+ "arch=compute_70,code=sm_70",
43
+ "--use_fast_math",
44
+ ]
45
+ + extra_cuda_flags
46
+ + cc_flag,
47
+ verbose=True,
48
+ )
49
+
50
+ extra_cuda_flags = [
51
+ "-U__CUDA_NO_HALF_OPERATORS__",
52
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
53
+ "--expt-relaxed-constexpr",
54
+ "--expt-extended-lambda",
55
+ ]
56
+
57
+ sources = [
58
+ srcpath / "anti_alias_activation.cpp",
59
+ srcpath / "anti_alias_activation_cuda.cu",
60
+ ]
61
+ anti_alias_activation_cuda = _cpp_extention_load_helper(
62
+ "anti_alias_activation_cuda", sources, extra_cuda_flags
63
+ )
64
+
65
+ return anti_alias_activation_cuda
66
+
67
+
68
+ def _get_cuda_bare_metal_version(cuda_dir):
69
+ raw_output = subprocess.check_output(
70
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71
+ )
72
+ output = raw_output.split()
73
+ release_idx = output.index("release") + 1
74
+ release = output[release_idx].split(".")
75
+ bare_metal_major = release[0]
76
+ bare_metal_minor = release[1][0]
77
+
78
+ return raw_output, bare_metal_major, bare_metal_minor
79
+
80
+
81
+ def _create_build_dir(buildpath):
82
+ try:
83
+ os.mkdir(buildpath)
84
+ except OSError:
85
+ if not os.path.isdir(buildpath):
86
+ print(f"Creation of the build directory {buildpath} failed")
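For reference, a minimal invocation sketch for the loader above (an assumption of the intended use: nvcc and ninja must be available and compatible with the installed PyTorch build; the compiled extension is cached in the build/ directory next to this file):

# Illustrative only: JIT-compiles the fused anti-alias activation kernel on first call.
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load as aaa_load

anti_alias_activation_cuda = aaa_load.load()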
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ ):
17
+ super().__init__()
18
+ self.up_ratio = up_ratio
19
+ self.down_ratio = down_ratio
20
+ self.act = activation
21
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
22
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
23
+
24
+ # x: [B,C,T]
25
+ def forward(self, x):
26
+ x = self.upsample(x)
27
+ x = self.act(x)
28
+ x = self.downsample(x)
29
+
30
+ return x
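A small usage sketch for Activation1d (nn.SiLU stands in here only as an illustrative activation; the actual Snake/SnakeBeta activations live in the vocoder's activations module):

import torch
import torch.nn as nn
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d

# Upsample 2x, apply the pointwise nonlinearity, then downsample 2x back,
# which suppresses the aliasing the nonlinearity would otherwise introduce.
act = Activation1d(activation=nn.SiLU())
x = torch.randn(1, 64, 1000)   # [B, C, T]
y = act(x)                     # same shape as the input: [1, 64, 1000]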
modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ """
57
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58
+ """
59
+ filter_ /= filter_.sum()
60
+ filter = filter_.view(1, 1, kernel_size)
61
+
62
+ return filter
63
+
64
+
65
+ class LowPassFilter1d(nn.Module):
66
+ def __init__(
67
+ self,
68
+ cutoff=0.5,
69
+ half_width=0.6,
70
+ stride: int = 1,
71
+ padding: bool = True,
72
+ padding_mode: str = "replicate",
73
+ kernel_size: int = 12,
74
+ ):
75
+ """
76
+ kernel_size should be an even number for the StyleGAN3 setup; in this implementation, an odd number is also possible.
77
+ """
78
+ super().__init__()
79
+ if cutoff < -0.0:
80
+ raise ValueError("Minimum cutoff must be larger than zero.")
81
+ if cutoff > 0.5:
82
+ raise ValueError("A cutoff above 0.5 does not make sense.")
83
+ self.kernel_size = kernel_size
84
+ self.even = kernel_size % 2 == 0
85
+ self.pad_left = kernel_size // 2 - int(self.even)
86
+ self.pad_right = kernel_size // 2
87
+ self.stride = stride
88
+ self.padding = padding
89
+ self.padding_mode = padding_mode
90
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91
+ self.register_buffer("filter", filter)
92
+
93
+ # Input [B, C, T]
94
+ def forward(self, x):
95
+ _, C, _ = x.shape
96
+
97
+ if self.padding:
98
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
100
+
101
+ return out
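A shape-level sketch of the low-pass filter above (cutoff and half_width are in normalized frequency, where 0.5 corresponds to Nyquist; the values here are illustrative):

import torch
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.filter import LowPassFilter1d

# Keep roughly the lower half of the band and decimate by 2 via stride=2.
lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(2, 80, 400)    # [B, C, T]
y = lpf(x)                     # [2, 80, 200]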
modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = (
15
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16
+ )
17
+ self.stride = ratio
18
+ self.pad = self.kernel_size // ratio - 1
19
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
+ self.pad_right = (
21
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22
+ )
23
+ filter = kaiser_sinc_filter1d(
24
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25
+ )
26
+ self.register_buffer("filter", filter)
27
+
28
+ # x: [B, C, T]
29
+ def forward(self, x):
30
+ _, C, _ = x.shape
31
+
32
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
33
+ x = self.ratio * F.conv_transpose1d(
34
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35
+ )
36
+ x = x[..., self.pad_left : -self.pad_right]
37
+
38
+ return x
39
+
40
+
41
+ class DownSample1d(nn.Module):
42
+ def __init__(self, ratio=2, kernel_size=None):
43
+ super().__init__()
44
+ self.ratio = ratio
45
+ self.kernel_size = (
46
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47
+ )
48
+ self.lowpass = LowPassFilter1d(
49
+ cutoff=0.5 / ratio,
50
+ half_width=0.6 / ratio,
51
+ stride=ratio,
52
+ kernel_size=self.kernel_size,
53
+ )
54
+
55
+ def forward(self, x):
56
+ xx = self.lowpass(x)
57
+
58
+ return xx
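And a matching round-trip sketch for the resamplers:

import torch
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 32, 256)        # [B, C, T]
up = UpSample1d(ratio=2)(x)        # [1, 32, 512], band-limited 2x upsampling
down = DownSample1d(ratio=2)(up)   # [1, 32, 256], low-pass filter + decimate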
modules/audio_detokenizer/vocoder/bigvgan.py ADDED
@@ -0,0 +1,492 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Optional, Union, Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import Conv1d, ConvTranspose1d
15
+ from torch.nn.utils import weight_norm, remove_weight_norm
16
+
17
+ from modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta
18
+ from modules.audio_detokenizer.vocoder.utils import init_weights, get_padding
19
+ from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d
20
+ from modules.audio_detokenizer.vocoder.utils import AttrDict
21
+
22
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
23
+
24
+
25
+ def load_hparams_from_json(path) -> AttrDict:
26
+ with open(path) as f:
27
+ data = f.read()
28
+ return AttrDict(json.loads(data))
29
+
30
+
31
+ class AMPBlock1(torch.nn.Module):
32
+ """
33
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
34
+ AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
35
+
36
+ Args:
37
+ h (AttrDict): Hyperparameters.
38
+ channels (int): Number of convolution channels.
39
+ kernel_size (int): Size of the convolution kernel. Default is 3.
40
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
41
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ h: AttrDict,
47
+ channels: int,
48
+ kernel_size: int = 3,
49
+ dilation: tuple = (1, 3, 5),
50
+ activation: str = None,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.h = h
55
+
56
+ self.convs1 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ stride=1,
64
+ dilation=d,
65
+ padding=get_padding(kernel_size, d),
66
+ )
67
+ )
68
+ for d in dilation
69
+ ]
70
+ )
71
+ self.convs1.apply(init_weights)
72
+
73
+ self.convs2 = nn.ModuleList(
74
+ [
75
+ weight_norm(
76
+ Conv1d(
77
+ channels,
78
+ channels,
79
+ kernel_size,
80
+ stride=1,
81
+ dilation=1,
82
+ padding=get_padding(kernel_size, 1),
83
+ )
84
+ )
85
+ for _ in range(len(dilation))
86
+ ]
87
+ )
88
+ self.convs2.apply(init_weights)
89
+
90
+ self.num_layers = len(self.convs1) + len(
91
+ self.convs2
92
+ ) # Total number of conv layers
93
+
94
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
95
+ if self.h.get("use_cuda_kernel", False):
96
+ from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
97
+ Activation1d as CudaActivation1d,
98
+ )
99
+
100
+ Activation1d = CudaActivation1d
101
+ else:
102
+ Activation1d = TorchActivation1d
103
+
104
+ # Activation functions
105
+ if activation == "snake":
106
+ self.activations = nn.ModuleList(
107
+ [
108
+ Activation1d(
109
+ activation=Snake(
110
+ channels, alpha_logscale=h.snake_logscale
111
+ )
112
+ )
113
+ for _ in range(self.num_layers)
114
+ ]
115
+ )
116
+ elif activation == "snakebeta":
117
+ self.activations = nn.ModuleList(
118
+ [
119
+ Activation1d(
120
+ activation=SnakeBeta(
121
+ channels, alpha_logscale=h.snake_logscale
122
+ )
123
+ )
124
+ for _ in range(self.num_layers)
125
+ ]
126
+ )
127
+ else:
128
+ raise NotImplementedError(
129
+ "activation incorrectly specified. check the config file and look for 'activation'."
130
+ )
131
+
132
+ def forward(self, x):
133
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
134
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
135
+ xt = a1(x)
136
+ xt = c1(xt)
137
+ xt = a2(xt)
138
+ xt = c2(xt)
139
+ x = xt + x
140
+
141
+ return x
142
+
143
+ def remove_weight_norm(self):
144
+ for l in self.convs1:
145
+ remove_weight_norm(l)
146
+ for l in self.convs2:
147
+ remove_weight_norm(l)
148
+
149
+
150
+ class AMPBlock2(torch.nn.Module):
151
+ """
152
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
153
+ Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
154
+
155
+ Args:
156
+ h (AttrDict): Hyperparameters.
157
+ channels (int): Number of convolution channels.
158
+ kernel_size (int): Size of the convolution kernel. Default is 3.
159
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
160
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ h: AttrDict,
166
+ channels: int,
167
+ kernel_size: int = 3,
168
+ dilation: tuple = (1, 3, 5),
169
+ activation: str = None,
170
+ ):
171
+ super().__init__()
172
+
173
+ self.h = h
174
+
175
+ self.convs = nn.ModuleList(
176
+ [
177
+ weight_norm(
178
+ Conv1d(
179
+ channels,
180
+ channels,
181
+ kernel_size,
182
+ stride=1,
183
+ dilation=d,
184
+ padding=get_padding(kernel_size, d),
185
+ )
186
+ )
187
+ for d in dilation
188
+ ]
189
+ )
190
+ self.convs.apply(init_weights)
191
+
192
+ self.num_layers = len(self.convs) # Total number of conv layers
193
+
194
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
195
+ if self.h.get("use_cuda_kernel", False):
196
+ from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
197
+ Activation1d as CudaActivation1d,
198
+ )
199
+
200
+ Activation1d = CudaActivation1d
201
+ else:
202
+ Activation1d = TorchActivation1d
203
+
204
+ # Activation functions
205
+ if activation == "snake":
206
+ self.activations = nn.ModuleList(
207
+ [
208
+ Activation1d(
209
+ activation=Snake(
210
+ channels, alpha_logscale=h.snake_logscale
211
+ )
212
+ )
213
+ for _ in range(self.num_layers)
214
+ ]
215
+ )
216
+ elif activation == "snakebeta":
217
+ self.activations = nn.ModuleList(
218
+ [
219
+ Activation1d(
220
+ activation=SnakeBeta(
221
+ channels, alpha_logscale=h.snake_logscale
222
+ )
223
+ )
224
+ for _ in range(self.num_layers)
225
+ ]
226
+ )
227
+ else:
228
+ raise NotImplementedError(
229
+ "activation incorrectly specified. check the config file and look for 'activation'."
230
+ )
231
+
232
+ def forward(self, x):
233
+ for c, a in zip(self.convs, self.activations):
234
+ xt = a(x)
235
+ xt = c(xt)
236
+ x = xt + x
+
+ return x
237
+
238
+ def remove_weight_norm(self):
239
+ for l in self.convs:
240
+ remove_weight_norm(l)
241
+
242
+
243
+ class BigVGAN(
244
+ torch.nn.Module,
245
+ PyTorchModelHubMixin,
246
+ library_name="bigvgan",
247
+ repo_url="https://github.com/NVIDIA/BigVGAN",
248
+ docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
249
+ pipeline_tag="audio-to-audio",
250
+ license="mit",
251
+ tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
252
+ ):
253
+ """
254
+ BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
255
+ New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
256
+
257
+ Args:
258
+ h (AttrDict): Hyperparameters.
259
+ use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
260
+
261
+ Note:
262
+ - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
263
+ - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
264
+ """
265
+
266
+ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
267
+ super().__init__()
268
+ self.h = h
269
+ self.h["use_cuda_kernel"] = use_cuda_kernel
270
+
271
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
272
+ if self.h.get("use_cuda_kernel", False):
273
+ from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
274
+ Activation1d as CudaActivation1d,
275
+ )
276
+
277
+ Activation1d = CudaActivation1d
278
+ else:
279
+ Activation1d = TorchActivation1d
280
+
281
+ self.num_kernels = len(h.resblock_kernel_sizes)
282
+ self.num_upsamples = len(h.upsample_rates)
283
+
284
+ # Pre-conv
285
+ self.conv_pre = weight_norm(
286
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
287
+ )
288
+
289
+ # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
290
+ if h.resblock == "1":
291
+ resblock_class = AMPBlock1
292
+ elif h.resblock == "2":
293
+ resblock_class = AMPBlock2
294
+ else:
295
+ raise ValueError(
296
+ f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
297
+ )
298
+
299
+ # Transposed conv-based upsamplers. does not apply anti-aliasing
300
+ self.ups = nn.ModuleList()
301
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
302
+ self.ups.append(
303
+ nn.ModuleList(
304
+ [
305
+ weight_norm(
306
+ ConvTranspose1d(
307
+ h.upsample_initial_channel // (2**i),
308
+ h.upsample_initial_channel // (2 ** (i + 1)),
309
+ k,
310
+ u,
311
+ padding=(k - u) // 2,
312
+ )
313
+ )
314
+ ]
315
+ )
316
+ )
317
+
318
+ # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
319
+ self.resblocks = nn.ModuleList()
320
+ for i in range(len(self.ups)):
321
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
322
+ for j, (k, d) in enumerate(
323
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
324
+ ):
325
+ self.resblocks.append(
326
+ resblock_class(h, ch, k, d, activation=h.activation)
327
+ )
328
+
329
+ # Post-conv
330
+ activation_post = (
331
+ Snake(ch, alpha_logscale=h.snake_logscale)
332
+ if h.activation == "snake"
333
+ else (
334
+ SnakeBeta(ch, alpha_logscale=h.snake_logscale)
335
+ if h.activation == "snakebeta"
336
+ else None
337
+ )
338
+ )
339
+ if activation_post is None:
340
+ raise NotImplementedError(
341
+ "activation incorrectly specified. check the config file and look for 'activation'."
342
+ )
343
+
344
+ self.activation_post = Activation1d(activation=activation_post)
345
+
346
+ # Whether to use bias for the final conv_post. Default to True for backward compatibility
347
+ self.use_bias_at_final = h.get("use_bias_at_final", True)
348
+ self.conv_post = weight_norm(
349
+ Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
350
+ )
351
+
352
+ # Weight initialization
353
+ for i in range(len(self.ups)):
354
+ self.ups[i].apply(init_weights)
355
+ self.conv_post.apply(init_weights)
356
+
357
+ # Final tanh activation. Defaults to True for backward compatibility
358
+ self.use_tanh_at_final = h.get("use_tanh_at_final", True)
359
+
360
+ def forward(self, x):
361
+ # Pre-conv
362
+ x = self.conv_pre(x)
363
+
364
+ for i in range(self.num_upsamples):
365
+ # Upsampling
366
+ for i_up in range(len(self.ups[i])):
367
+ x = self.ups[i][i_up](x)
368
+ # AMP blocks
369
+ xs = None
370
+ for j in range(self.num_kernels):
371
+ if xs is None:
372
+ xs = self.resblocks[i * self.num_kernels + j](x)
373
+ else:
374
+ xs += self.resblocks[i * self.num_kernels + j](x)
375
+ x = xs / self.num_kernels
376
+
377
+ # Post-conv
378
+ x = self.activation_post(x)
379
+ x = self.conv_post(x)
380
+ # Final tanh activation
381
+ if self.use_tanh_at_final:
382
+ x = torch.tanh(x)
383
+ else:
384
+ x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
385
+
386
+ return x
387
+
388
+ def remove_weight_norm(self):
389
+ try:
390
+ print("Removing weight norm...")
391
+ for l in self.ups:
392
+ for l_i in l:
393
+ remove_weight_norm(l_i)
394
+ for l in self.resblocks:
395
+ l.remove_weight_norm()
396
+ remove_weight_norm(self.conv_pre)
397
+ remove_weight_norm(self.conv_post)
398
+ except ValueError:
399
+ print("[INFO] Model already removed weight norm. Skipping!")
400
+ pass
401
+
402
+ # Additional methods for huggingface_hub support
403
+ def _save_pretrained(self, save_directory: Path) -> None:
404
+ """Save weights and config.json from a Pytorch model to a local directory."""
405
+
406
+ model_path = save_directory / "bigvgan_generator.pt"
407
+ torch.save({"generator": self.state_dict()}, model_path)
408
+
409
+ config_path = save_directory / "config.json"
410
+ with open(config_path, "w") as config_file:
411
+ json.dump(self.h, config_file, indent=4)
412
+
413
+ @classmethod
414
+ def _from_pretrained(
415
+ cls,
416
+ *,
417
+ model_id: str,
418
+ revision: str,
419
+ cache_dir: str,
420
+ force_download: bool,
421
+ proxies: Optional[Dict],
422
+ resume_download: bool,
423
+ local_files_only: bool,
424
+ token: Union[str, bool, None],
425
+ map_location: str = "cpu", # Additional argument
426
+ strict: bool = False, # Additional argument
427
+ use_cuda_kernel: bool = False,
428
+ **model_kwargs,
429
+ ):
430
+ """Load Pytorch pretrained weights and return the loaded model."""
431
+
432
+ # Download and load hyperparameters (h) used by BigVGAN
433
+ if os.path.isdir(model_id):
434
+ print("Loading config.json from local directory")
435
+ config_file = os.path.join(model_id, "config.json")
436
+ else:
437
+ config_file = hf_hub_download(
438
+ repo_id=model_id,
439
+ filename="config.json",
440
+ revision=revision,
441
+ cache_dir=cache_dir,
442
+ force_download=force_download,
443
+ proxies=proxies,
444
+ resume_download=resume_download,
445
+ token=token,
446
+ local_files_only=local_files_only,
447
+ )
448
+ h = load_hparams_from_json(config_file)
449
+
450
+ # instantiate BigVGAN using h
451
+ if use_cuda_kernel:
452
+ print(
453
+ f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
454
+ )
455
+ print(
456
+ f"[WARNING] You need nvcc and ninja installed on your system, matching the CUDA toolkit your PyTorch build was compiled with, in order to build the kernel. If not, the model will fail to initialize or will generate incorrect waveforms!"
457
+ )
458
+ print(
459
+ f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
460
+ )
461
+ model = cls(h, use_cuda_kernel=use_cuda_kernel)
462
+
463
+ # Download and load pretrained generator weight
464
+ if os.path.isdir(model_id):
465
+ print("Loading weights from local directory")
466
+ model_file = os.path.join(model_id, "bigvgan_generator.pt")
467
+ else:
468
+ print(f"Loading weights from {model_id}")
469
+ model_file = hf_hub_download(
470
+ repo_id=model_id,
471
+ filename="bigvgan_generator.pt",
472
+ revision=revision,
473
+ cache_dir=cache_dir,
474
+ force_download=force_download,
475
+ proxies=proxies,
476
+ resume_download=resume_download,
477
+ token=token,
478
+ local_files_only=local_files_only,
479
+ )
480
+
481
+ checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)
482
+
483
+ try:
484
+ model.load_state_dict(checkpoint_dict["generator"])
485
+ except RuntimeError:
486
+ print(
487
+ f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
488
+ )
489
+ model.remove_weight_norm()
490
+ model.load_state_dict(checkpoint_dict["generator"])
491
+
492
+ return model
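An inference sketch for the generator above. The checkpoint id is only an example of NVIDIA's published BigVGAN weights (an assumption here, not what this repository necessarily loads); any local directory containing config.json and bigvgan_generator.pt works through the same from_pretrained path.

import torch
from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN

# Example checkpoint id (assumption); a local directory path also works.
model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_22khz_80band_256x", use_cuda_kernel=False)
model.remove_weight_norm()
model.eval()

mel = torch.randn(1, model.h.num_mels, 100)   # [B, num_mels, frames]
with torch.inference_mode():
    wav = model(mel)                          # [1, 1, frames * prod(upsample_rates)]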
modules/audio_detokenizer/vocoder/utils.py ADDED
@@ -0,0 +1,105 @@
1
+ from librosa.filters import mel as librosa_mel_fn
2
+ import torch
3
+ import os
4
+ mel_basis_cache = {}
5
+ hann_window_cache = {}
6
+
7
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
8
+ return torch.log(torch.clamp(x, min=clip_val) * C)
9
+
10
+
11
+ def spectral_normalize_torch(magnitudes):
12
+ return dynamic_range_compression_torch(magnitudes)
13
+
14
+ def get_melspec(
15
+ y: torch.Tensor,
16
+ n_fft: int,
17
+ num_mels: int,
18
+ sampling_rate: int,
19
+ hop_size: int,
20
+ win_size: int,
21
+ fmin: int,
22
+ fmax: int = None,
23
+ center: bool = False,
24
+ ) -> torch.Tensor:
25
+ """
26
+ Calculate the mel spectrogram of an input signal.
27
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
28
+
29
+ Args:
30
+ y (torch.Tensor): Input signal.
31
+ n_fft (int): FFT size.
32
+ num_mels (int): Number of mel bins.
33
+ sampling_rate (int): Sampling rate of the input signal.
34
+ hop_size (int): Hop size for STFT.
35
+ win_size (int): Window size for STFT.
36
+ fmin (int): Minimum frequency for mel filterbank.
37
+ fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
38
+ center (bool): Whether to pad the input to center the frames. Default is False.
39
+
40
+ Returns:
41
+ torch.Tensor: Mel spectrogram.
42
+ """
43
+ if torch.min(y) < -1.0:
44
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
45
+ if torch.max(y) > 1.0:
46
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
47
+
48
+ device = y.device
49
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
50
+
51
+ if key not in mel_basis_cache:
52
+ mel = librosa_mel_fn(
53
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
54
+ )
55
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
56
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
57
+
58
+ mel_basis = mel_basis_cache[key]
59
+ hann_window = hann_window_cache[key]
60
+
61
+ padding = (n_fft - hop_size) // 2
62
+ y = torch.nn.functional.pad(
63
+ y.unsqueeze(1), (padding, padding), mode="reflect"
64
+ ).squeeze(1)
65
+
66
+ spec = torch.stft(
67
+ y,
68
+ n_fft,
69
+ hop_length=hop_size,
70
+ win_length=win_size,
71
+ window=hann_window,
72
+ center=center,
73
+ pad_mode="reflect",
74
+ normalized=False,
75
+ onesided=True,
76
+ return_complex=True,
77
+ )
78
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
79
+
80
+ mel_spec = torch.matmul(mel_basis, spec)
81
+ mel_spec = spectral_normalize_torch(mel_spec)
82
+
83
+ return mel_spec
84
+
85
+
86
+ class AttrDict(dict):
87
+ def __init__(self, *args, **kwargs):
88
+ super(AttrDict, self).__init__(*args, **kwargs)
89
+ self.__dict__ = self
90
+
91
+ def load_checkpoint(filepath, device):
92
+ assert os.path.isfile(filepath)
93
+ print(f"Loading '{filepath}'")
94
+ checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
95
+ print("Complete.")
96
+ return checkpoint_dict
97
+
98
+ def init_weights(m, mean=0.0, std=0.01):
99
+ classname = m.__class__.__name__
100
+ if classname.find("Conv") != -1:
101
+ m.weight.data.normal_(mean, std)
102
+
103
+
104
+ def get_padding(kernel_size, dilation=1):
105
+ return int((kernel_size * dilation - dilation) / 2)
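A quick sketch of get_melspec with BigVGAN-style 22.05 kHz / 80-mel settings (the real values come from the vocoder config, so treat these numbers as placeholders):

import torch
from modules.audio_detokenizer.vocoder.utils import get_melspec

y = torch.rand(1, 22050) * 2 - 1   # 1 s of audio in [-1, 1], shape [B, T]
mel = get_melspec(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                  hop_size=256, win_size=1024, fmin=0, fmax=None)
# mel: [1, 80, ~T // hop_size] log-magnitude mel spectrogram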
modules/audio_tokenizer/audio_tokenizer.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import librosa
3
+ import yaml
4
+ from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
5
+ import safetensors
6
+ import accelerate
7
+ import soundfile as sf
8
+ import math
9
+ from einops import rearrange
10
+ from modules.audio_tokenizer.rep_codec import RepCodec
11
+
12
+
13
+ class AudioTokenizer(object):
14
+ def __init__(self, **kwargs):
15
+ self.device = kwargs.pop('device')
16
+ print(self.device)
17
+ # tokenize
18
+ feat_stats = kwargs.pop('feat_stats')
19
+ feat_stats = torch.load(feat_stats, map_location='cpu')
20
+ self.feat_mean = feat_stats['mean']
21
+ self.feat_std = torch.sqrt(feat_stats['var'])
22
+ wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
23
+ self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
24
+ self.semantic_model.eval()
25
+ self.semantic_model.to(self.device)
26
+ self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
27
+
28
+ self.semantic_codec = RepCodec()
29
+ self.semantic_codec.eval()
30
+ pretrained_path = kwargs.pop("semantic_codec_ckpt")
31
+ safetensors.torch.load_model(self.semantic_codec, pretrained_path)
32
+ self.semantic_codec.to(self.device)
33
+
34
+ self.max_length = 2048
35
+
36
+
37
+ @torch.no_grad()
38
+ def tokenize(self, speech):
39
+ # Input:
40
+ # speech: torch tensor, shape[B, N_speech]
41
+ # Output:
42
+ # semantic token: torch tensor, shape[B, N]
43
+
44
+ inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
45
+ input_features = inputs["input_features"].to(self.device)
46
+ attention_mask = inputs["attention_mask"].to(self.device)
47
+ seg_num = math.ceil(input_features.shape[1] / self.max_length)
48
+ pad_num = seg_num * self.max_length - input_features.shape[1]
49
+ input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0,0), value=0)
50
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
51
+ input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s =seg_num)
52
+ attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)
53
+
54
+
55
+ feats = self.semantic_model(
56
+ input_features=input_features,
57
+ attention_mask=attention_mask,
58
+ output_hidden_states=True,
59
+ )
60
+ feat = feats.hidden_states[17]
61
+ feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
62
+ feat = feat[:, :feat.shape[1]-pad_num, :]
63
+ feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
64
+ semantic_token, _ = self.semantic_codec.quantize(feat)
65
+ return semantic_token
66
+
67
+ def get_audio_tokenizer():
68
+ config = dict()
69
+ config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
70
+ config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'
71
+ config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'
72
+ config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'
73
+ audio_tokenizer = AudioTokenizer(**config)
74
+ return audio_tokenizer
75
+
76
+ get_audio_tokenizer()
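An illustrative tokenization sketch for the helper above (it assumes the resources/audio_tokenizer/ checkpoints referenced in get_audio_tokenizer have already been fetched; note that the module also eagerly builds a tokenizer when imported):

import torch
import librosa
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer

tokenizer = get_audio_tokenizer()
wav, _ = librosa.load("zh_prompt0.wav", sr=16000)           # mono, 16 kHz prompt from this repo
tokens = tokenizer.tokenize(torch.from_numpy(wav)[None])    # [1, N] semantic token ids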
modules/audio_tokenizer/quantize/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .vector_quantize import VectorQuantize
2
+ from .residual_vq import ResidualVQ
3
+ from .factorized_vector_quantize import FactorizedVectorQuantize
modules/audio_tokenizer/quantize/factorized_vector_quantize.py ADDED
@@ -0,0 +1,145 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ class FactorizedVectorQuantize(nn.Module):
18
+ def __init__(
19
+ self,
20
+ input_dim,
21
+ codebook_size,
22
+ codebook_dim,
23
+ commitment=0.005,
24
+ codebook_loss_weight=1.0,
25
+ use_l2_normlize=True,
26
+ ):
27
+ super().__init__()
28
+ self.input_dim = input_dim
29
+ self.codebook_size = codebook_size
30
+ self.codebook_dim = codebook_dim
31
+ self.commitment = commitment
32
+ self.codebook_loss_weight = codebook_loss_weight
33
+ self.use_l2_normlize = use_l2_normlize
34
+
35
+ if self.input_dim != self.codebook_dim:
36
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
37
+ self.out_project = WNConv1d(
38
+ self.codebook_dim, self.input_dim, kernel_size=1
39
+ )
40
+
41
+ else:
42
+ self.in_project = nn.Identity()
43
+ self.out_project = nn.Identity()
44
+
45
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
46
+
47
+ def forward(self, z):
48
+ """
49
+ Parameters
50
+ ----------
51
+ z: torch.Tensor[B x D x T]
52
+
53
+ Returns
54
+ -------
55
+ z_q: torch.Tensor[B x D x T]
56
+ Quantized continuous representation of input
57
+ commit_loss: Tensor[B]
58
+ Commitment loss to train encoder to predict vectors closer to codebook entries
59
+ codebook_loss: Tensor[B]
60
+ Codebook loss to update the codebook
61
+ indices: torch.Tensor[B x T]
62
+ Codebook indices (quantized discrete representation of input)
63
+ z_e: torch.Tensor[B x D x T]
64
+ Projected latents (continuous representation of input before quantization)
65
+ """
66
+
67
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
68
+ z_e = self.in_project(z)
69
+ z_q, indices = self.decode_latents(z_e)
70
+
71
+ # Compute commitment loss and codebook loss
72
+ if self.training:
73
+ commit_loss = (
74
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
75
+ * self.commitment
76
+ )
77
+ codebook_loss = (
78
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
79
+ * self.codebook_loss_weight
80
+ )
81
+ else:
82
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
83
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
84
+
85
+ z_q = z_e + (z_q - z_e).detach()
86
+
87
+ z_q = self.out_project(z_q)
88
+
89
+ return z_q, commit_loss, codebook_loss, indices, z_e
90
+
91
+ def embed_code(self, embed_id):
92
+ return F.embedding(embed_id, self.codebook.weight)
93
+
94
+ def decode_code(self, embed_id):
95
+ return self.embed_code(embed_id).transpose(1, 2)
96
+
97
+ def decode_latents(self, latents):
98
+ encodings = rearrange(latents, "b d t -> (b t) d")
99
+ codebook = self.codebook.weight
100
+
101
+ # L2 normalize encodings and codebook
102
+ if self.use_l2_normlize:
103
+ encodings = F.normalize(encodings)
104
+ codebook = F.normalize(codebook)
105
+
106
+ # Compute euclidean distance between encodings and codebook,
107
+ # if use_l2_normlize is True, the distance is equal to cosine distance
108
+ dist = (
109
+ encodings.pow(2).sum(1, keepdim=True)
110
+ - 2 * encodings @ codebook.t()
111
+ + codebook.pow(2).sum(1, keepdim=True).t()
112
+ )
113
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
114
+ z_q = self.decode_code(indices)
115
+
116
+ return z_q, indices
117
+
118
+ def vq2emb(self, vq, out_proj=True):
119
+ emb = self.decode_code(vq)
120
+ if out_proj:
121
+ emb = self.out_project(emb)
122
+ return emb
123
+
124
+ def latent2dist(self, latents):
125
+ encodings = rearrange(latents, "b d t -> (b t) d")
126
+ codebook = self.codebook.weight
127
+
128
+ # L2 normalize encodings and codebook
129
+ if self.use_l2_normlize:
130
+ encodings = F.normalize(encodings)
131
+ codebook = F.normalize(codebook)
132
+
133
+ # Compute euclidean distance between encodings and codebook,
134
+ # if use_l2_normlize is True, the distance is equal to cosine distance
135
+ dist = (
136
+ encodings.pow(2).sum(1, keepdim=True)
137
+ - 2 * encodings @ codebook.t()
138
+ + codebook.pow(2).sum(1, keepdim=True).t()
139
+ ) # (b*t, k)
140
+
141
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
142
+ dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
143
+ z_q = self.decode_code(indices)
144
+
145
+ return -dist, indices, z_q
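A shape-level sketch of the factorized quantizer (the 1024/8192/8 sizes mirror the RepCodec defaults used later in rep_codec.py):

import torch
from modules.audio_tokenizer.quantize import FactorizedVectorQuantize

fvq = FactorizedVectorQuantize(input_dim=1024, codebook_size=8192, codebook_dim=8)
fvq.eval()
z = torch.randn(2, 1024, 50)                    # [B, D, T]
z_q, commit, cb_loss, indices, z_e = fvq(z)     # z_q: [2, 1024, 50], indices: [2, 50]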
modules/audio_tokenizer/quantize/residual_vq.py ADDED
@@ -0,0 +1,168 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ from .vector_quantize import VectorQuantize
12
+ from .factorized_vector_quantize import FactorizedVectorQuantize
13
+
14
+
15
+ class ResidualVQ(nn.Module):
16
+ """
17
+ Introduced in SoundStream: An end2end neural audio codec
18
+ https://arxiv.org/abs/2107.03312
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ input_dim: int = 256,
24
+ num_quantizers: int = 8,
25
+ codebook_size: int = 1024,
26
+ codebook_dim: int = 256,
27
+ quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
28
+ quantizer_dropout: float = 0.5,
29
+ **kwargs,
30
+ ):
31
+ super().__init__()
32
+
33
+ self.input_dim = input_dim
34
+ self.num_quantizers = num_quantizers
35
+ self.codebook_size = codebook_size
36
+ self.codebook_dim = codebook_dim
37
+ self.quantizer_type = quantizer_type
38
+ self.quantizer_dropout = quantizer_dropout
39
+
40
+ if quantizer_type == "vq":
41
+ VQ = VectorQuantize
42
+ elif quantizer_type == "fvq":
43
+ VQ = FactorizedVectorQuantize
44
+ else:
45
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
46
+
47
+ self.quantizers = nn.ModuleList(
48
+ [
49
+ VQ(
50
+ input_dim=input_dim,
51
+ codebook_size=codebook_size,
52
+ codebook_dim=codebook_dim,
53
+ **kwargs,
54
+ )
55
+ for _ in range(num_quantizers)
56
+ ]
57
+ )
58
+
59
+ def forward(self, z, n_quantizers: int = None):
60
+ """
61
+ Parameters
62
+ ----------
63
+ z : Tensor[B x D x T]
64
+ n_quantizers : int, optional
65
+ No. of quantizers to use
66
+ (n_quantizers < self.num_quantizers, e.g. for quantizer dropout)
67
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
68
+ when in training mode, and a random number of quantizers is used.
69
+ Returns
70
+ -------
71
+ "quantized_out" : Tensor[B x D x T]
72
+ Quantized continuous representation of input
73
+ "all_indices" : Tensor[N x B x T]
74
+ Codebook indices for each codebook
75
+ (quantized discrete representation of input)
76
+ "all_commit_losses" : Tensor[N]
77
+ "all_codebook_losses" : Tensor[N]
78
+ "all_quantized" : Tensor[N x B x D x T]
79
+ """
80
+
81
+ quantized_out = 0.0
82
+ residual = z
83
+
84
+ all_commit_losses = []
85
+ all_codebook_losses = []
86
+ all_indices = []
87
+ all_quantized = []
88
+
89
+ if n_quantizers is None:
90
+ n_quantizers = self.num_quantizers
91
+
92
+ if self.training:
93
+ n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
94
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
95
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
96
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
97
+ n_quantizers = n_quantizers.to(z.device)
98
+
99
+ for i, quantizer in enumerate(self.quantizers):
100
+ if self.training is False and i >= n_quantizers:
101
+ break
102
+
103
+ z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
104
+ residual
105
+ )
106
+
107
+ # Create mask to apply quantizer dropout
108
+ mask = (
109
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
110
+ )
111
+ quantized_out = quantized_out + z_q_i * mask[:, None, None]
112
+ residual = residual - z_q_i
113
+
114
+ commit_loss_i = (commit_loss_i * mask).mean()
115
+ codebook_loss_i = (codebook_loss_i * mask).mean()
116
+
117
+ all_commit_losses.append(commit_loss_i)
118
+ all_codebook_losses.append(codebook_loss_i)
119
+ all_indices.append(indices_i)
120
+ all_quantized.append(z_q_i)
121
+
122
+ all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
123
+ torch.stack,
124
+ (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
125
+ )
126
+
127
+ return (
128
+ quantized_out,
129
+ all_indices,
130
+ all_commit_losses,
131
+ all_codebook_losses,
132
+ all_quantized,
133
+ )
134
+
135
+ def vq2emb(self, vq, n_quantizers=None):
136
+ quantized_out = 0.0
137
+ if n_quantizers is None:
138
+ n_quantizers = self.num_quantizers
139
+ for idx, quantizer in enumerate(self.quantizers):
140
+ if idx >= n_quantizers:
141
+ break
142
+ quantized_out += quantizer.vq2emb(vq[idx])
143
+ return quantized_out
144
+
145
+ def latent2dist(self, z, n_quantizers=None):
146
+ quantized_out = 0.0
147
+ residual = z
148
+
149
+ all_dists = []
150
+ all_indices = []
151
+
152
+ if n_quantizers is None:
153
+ n_quantizers = self.num_quantizers
154
+
155
+ for i, quantizer in enumerate(self.quantizers):
156
+ if self.training is False and i >= n_quantizers:
157
+ break
158
+ dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
159
+ all_dists.append(dist_i)
160
+ all_indices.append(indices_i)
161
+
162
+ quantized_out = quantized_out + z_q_i
163
+ residual = residual - z_q_i
164
+
165
+ all_dists = torch.stack(all_dists)
166
+ all_indices = torch.stack(all_indices)
167
+
168
+ return all_dists, all_indices
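And a corresponding sketch for the residual stack (quantizer_type="fvq" selects the factorized quantizer above; in eval mode all num_quantizers stages are applied and the losses come back as zeros):

import torch
from modules.audio_tokenizer.quantize import ResidualVQ

rvq = ResidualVQ(input_dim=256, num_quantizers=8, codebook_size=1024,
                 codebook_dim=256, quantizer_type="fvq")
rvq.eval()
z = torch.randn(2, 256, 50)                           # [B, D, T]
z_q, indices, commit, cb_loss, quantized = rvq(z)     # indices: [8, 2, 50]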
modules/audio_tokenizer/quantize/vector_quantize.py ADDED
@@ -0,0 +1,396 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ def l2norm(t):
18
+ return F.normalize(t, p=2, dim=-1)
19
+
20
+
21
+ def ema_inplace(moving_avg, new, decay):
22
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
23
+
24
+
25
+ def laplace_smoothing(x, n_categories, eps=1e-5):
26
+ return (x + eps) / (x.sum() + n_categories * eps)
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
50
+ means, "c d -> () c d"
51
+ )
52
+ dists = -(diffs**2).sum(dim=-1)
53
+
54
+ buckets = dists.max(dim=-1).indices
55
+ bins = torch.bincount(buckets, minlength=num_clusters)
56
+ zero_mask = bins == 0
57
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
58
+
59
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
60
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
61
+ new_means = new_means / bins_min_clamped[..., None]
62
+
63
+ if use_cosine_sim:
64
+ new_means = l2norm(new_means)
65
+
66
+ means = torch.where(zero_mask[..., None], means, new_means)
67
+
68
+ return means, bins
69
+
70
+
71
+ class EuclideanCodebook(nn.Module):
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ codebook_size,
76
+ kmeans_init=False,
77
+ kmeans_iters=10,
78
+ decay=0.8,
79
+ eps=1e-5,
80
+ threshold_ema_dead_code=2,
81
+ weight_init=False,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.decay = decay
86
+ init_fn = torch.randn if not weight_init else torch.zeros
87
+ embed = init_fn(codebook_size, dim)
88
+
89
+ if weight_init:
90
+ nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
91
+
92
+ self.codebook_size = codebook_size
93
+ self.kmeans_iters = kmeans_iters
94
+ self.eps = eps
95
+ self.threshold_ema_dead_code = threshold_ema_dead_code
96
+
97
+ self.register_buffer(
98
+ "initted", torch.Tensor([not kmeans_init])
99
+ ) # if kmeans_init is True, then initted is False; otherwise, initted is True
100
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
101
+ self.register_buffer("embed", embed)
102
+ self.register_buffer("embed_avg", embed.clone())
103
+
104
+ def init_embed_(self, data):
105
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
106
+ self.embed.data.copy_(embed)
107
+ self.embed_avg.data.copy_(embed)
108
+ self.cluster_size.data.copy_(cluster_size)
109
+ self.initted.data.copy_(torch.Tensor([True]))
110
+
111
+ def replace(self, samples, mask):
112
+ modified_codebook = torch.where(
113
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
114
+ )
115
+ self.embed.data.copy_(modified_codebook)
116
+
117
+ def expire_codes_(self, batch_samples):
118
+ if self.threshold_ema_dead_code == 0:
119
+ return
120
+
121
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
122
+ if not torch.any(expired_codes):
123
+ return
124
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
125
+ self.replace(batch_samples, mask=expired_codes)
126
+
127
+ def forward(self, x):
128
+ shape, dtype = x.shape, x.dtype
129
+ flatten = rearrange(x, "... d -> (...) d")
130
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
131
+
132
+ if not self.initted:
133
+ self.init_embed_(flatten)
134
+
135
+ dist = -(
136
+ flatten.pow(2).sum(1, keepdim=True)
137
+ - 2 * flatten @ embed
138
+ + embed.pow(2).sum(0, keepdim=True)
139
+ )
140
+
141
+ embed_ind = dist.max(dim=-1).indices
142
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
143
+ embed_ind = embed_ind.view(*shape[:-1])
144
+ quantize = F.embedding(embed_ind, self.embed)
145
+
146
+ if self.training:
147
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
148
+ embed_sum = (
149
+ flatten.t() @ embed_onehot
150
+ ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
151
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
152
+ cluster_size = (
153
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
154
+ * self.cluster_size.sum()
155
+ )
156
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
157
+ self.embed.data.copy_(embed_normalized)
158
+ self.expire_codes_(x)
159
+
160
+ return quantize, embed_ind
161
+
162
+ def vq2emb(self, vq):
163
+ quantize = F.embedding(vq, self.embed)
164
+ return quantize
165
+
166
+ def latent2dist(self, x):
167
+ shape, dtype = x.shape, x.dtype
168
+ flatten = rearrange(x, "... d -> (...) d")
169
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
170
+
171
+ if not self.initted:
172
+ self.init_embed_(flatten)
173
+
174
+ dist = -(
175
+ flatten.pow(2).sum(1, keepdim=True)
176
+ - 2 * flatten @ embed
177
+ + embed.pow(2).sum(0, keepdim=True)
178
+ )
179
+
180
+ embed_ind = dist.max(dim=-1).indices
181
+ embed_ind = embed_ind.view(*shape[:-1])
182
+ quantize = F.embedding(embed_ind, self.embed)
183
+
184
+ dist = dist.view(*shape[:-1], -1)
185
+
186
+ return dist, embed_ind, quantize
187
+
188
+
189
+ class SimpleCodebook(nn.Module):
190
+ def __init__(
191
+ self,
192
+ dim,
193
+ codebook_size,
194
+ use_l2_normlize=False,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.dim = dim
199
+ self.codebook_size = codebook_size
200
+ self.use_l2_normlize = use_l2_normlize
201
+
202
+ self.embed = nn.Embedding(self.codebook_size, self.dim)
203
+
204
+ def forward(self, x):
205
+ shape, dtype = x.shape, x.dtype
206
+ flatten = rearrange(x, "... d -> (...) d")
207
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
208
+
209
+ if self.use_l2_normlize:
210
+ flatten = F.normalize(flatten)
211
+ embed = F.normalize(embed)
212
+
213
+ dist = -(
214
+ flatten.pow(2).sum(1, keepdim=True)
215
+ - 2 * flatten @ embed
216
+ + embed.pow(2).sum(0, keepdim=True)
217
+ )
218
+
219
+ embed_ind = dist.max(dim=-1).indices
220
+ embed_ind = embed_ind.view(*shape[:-1])
221
+ quantize = F.embedding(embed_ind, self.embed.weight)
222
+
223
+ return quantize, embed_ind
224
+
225
+ def vq2emb(self, vq):
226
+ quantize = F.embedding(vq, self.embed.weight)
227
+ return quantize
228
+
229
+ def latent2dist(self, x):
230
+ shape, dtype = x.shape, x.dtype
231
+ flatten = rearrange(x, "... d -> (...) d")
232
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
233
+
234
+ if self.use_l2_normlize:
235
+ flatten = F.normalize(flatten)
236
+ embed = F.normalize(embed)
237
+
238
+ dist = -(
239
+ flatten.pow(2).sum(1, keepdim=True)
240
+ - 2 * flatten @ embed
241
+ + embed.pow(2).sum(0, keepdim=True)
242
+ )
243
+
244
+ embed_ind = dist.max(dim=-1).indices
245
+ embed_ind = embed_ind.view(*shape[:-1])
246
+ quantize = F.embedding(embed_ind, self.embed.weight)
247
+
248
+ dist = dist.view(*shape[:-1], -1)
249
+
250
+ return dist, embed_ind, quantize
251
+
252
+
253
+ class VectorQuantize(nn.Module):
254
+ """Vector quantization and factorized vector quantization implementation
255
+ Args:
256
+ input_dim (int): Dimension of input.
257
+ codebook_size (int): Codebook size.
258
+ codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
259
+ if codebook_type == "euclidean"; otherwise, if you want to use
260
+ factorized vector quantization, use a small codebook_dim (e.g. 8 or 32).
261
+ commitment (float): Weight for commitment loss.
262
+ use_l2_normlize (bool): Whether to use l2-normalized codes for factorized vector quantization;
263
+ we suggest setting it to True if you want to use factorized vector quantization.
264
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
265
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
266
+ decay (float): Decay for exponential moving average over the codebooks.
267
+ epsilon (float): Epsilon value for numerical stability.
268
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
269
+ that have an exponential moving average cluster size less than the specified threshold with
270
+ randomly selected vector from the current batch.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ input_dim,
276
+ codebook_size,
277
+ codebook_dim,
278
+ commitment=0.005,
279
+ codebook_loss_weight=1.0,
280
+ use_l2_normlize=False,
281
+ codebook_type="euclidean", # "euclidean" or "simple"
282
+ kmeans_init=False,
283
+ kmeans_iters=10,
284
+ decay=0.8,
285
+ eps=1e-5,
286
+ threshold_ema_dead_code=2,
287
+ weight_init=False,
288
+ ):
289
+ super().__init__()
290
+ self.input_dim = input_dim
291
+ self.codebook_size = codebook_size
292
+ self.codebook_dim = codebook_dim
293
+ self.commitment = commitment
294
+ self.codebook_loss_weight = codebook_loss_weight
295
+ self.use_l2_normlize = use_l2_normlize
296
+ self.codebook_type = codebook_type
297
+ self.kmeans_init = kmeans_init
298
+ self.kmeans_iters = kmeans_iters
299
+ self.decay = decay
300
+ self.eps = eps
301
+ self.threshold_ema_dead_code = threshold_ema_dead_code
302
+ self.weight_init = weight_init
303
+
304
+ if self.input_dim != self.codebook_dim:
305
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
306
+ self.out_project = WNConv1d(
307
+ self.codebook_dim, self.input_dim, kernel_size=1
308
+ )
309
+
310
+ else:
311
+ self.in_project = nn.Identity()
312
+ self.out_project = nn.Identity()
313
+
314
+ if self.codebook_type == "euclidean":
315
+ self.codebook = EuclideanCodebook(
316
+ self.codebook_dim,
317
+ codebook_size=self.codebook_size,
318
+ kmeans_init=self.kmeans_init,
319
+ kmeans_iters=self.kmeans_iters,
320
+ decay=self.decay,
321
+ eps=self.eps,
322
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
323
+ weight_init=self.weight_init,
324
+ )
325
+ elif self.codebook_type == "simple":
326
+ self.codebook = SimpleCodebook(
327
+ self.codebook_dim,
328
+ codebook_size=self.codebook_size,
329
+ use_l2_normlize=self.use_l2_normlize,
330
+ )
331
+ else:
332
+ raise NotImplementedError(
333
+ f"codebook_type {self.codebook_type} is not implemented!"
334
+ )
335
+
336
+ def forward(self, z):
337
+ """
338
+ Parameters
339
+ ----------
340
+ z: torch.Tensor[B x D x T]
341
+
342
+ Returns
343
+ -------
344
+ z_q: torch.Tensor[B x D x T]
345
+ Quantized continuous representation of input
346
+ commit_loss: Tensor[B]
347
+ Commitment loss to train encoder to predict vectors closer to codebook entries
348
+ codebook_loss: Tensor[B]
349
+ Codebook loss to update the codebook
350
+ indices: torch.Tensor[B x T]
351
+ Codebook indices (quantized discrete representation of input)
352
+ z_e: torch.Tensor[B x D x T]
353
+ Projected latents (continuous representation of input before quantization)
354
+ """
355
+
356
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
357
+ z_e = self.in_project(z)
358
+ z_q, indices = self.decode_latents(z_e)
359
+
360
+ # Compute commitment loss and codebook loss
361
+ if self.training:
362
+ commit_loss = (
363
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
364
+ * self.commitment
365
+ )
366
+ codebook_loss = (
367
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
368
+ * self.codebook_loss_weight
369
+ )
370
+ else:
371
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
372
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
373
+
374
+ z_q = z_e + (z_q - z_e).detach()
375
+
376
+ z_q = self.out_project(z_q)
377
+
378
+ return z_q, commit_loss, codebook_loss, indices, z_e
379
+
380
+ def decode_latents(self, latents):
381
+ encodings = rearrange(latents, "b d t -> b t d")
382
+ z_q, indices = self.codebook(encodings)
383
+ z_q = z_q.transpose(1, 2)
384
+ return z_q, indices
385
+
386
+ def vq2emb(self, vq, out_proj=True):
387
+ emb = self.codebook.vq2emb(vq)
388
+ emb = emb.transpose(1, 2)
389
+ if out_proj:
390
+ emb = self.out_project(emb)
391
+ return emb
392
+
393
+ def latent2dist(self, latents):
394
+ latents = rearrange(latents, "b d t -> b t d")
395
+ dist, embed_ind, quantize = self.codebook.latent2dist(latents)
396
+ return dist, embed_ind, quantize.transpose(1, 2)
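For reference, a minimal shape-check sketch of the factorized quantizer above. It assumes the class is named `FactorizedVectorQuantize` and is importable from this file's module path (neither is visible in this excerpt); sizes are illustrative. Per the `forward` shown above, the commitment and codebook losses come back as zeros in eval mode.

```python
import torch
# Assumed class name and import path for this file
from modules.audio_tokenizer.quantize.factorized_vector_quantize import FactorizedVectorQuantize

# 1024-dim latents factorized into an 8-dim codebook space (illustrative sizes)
vq = FactorizedVectorQuantize(input_dim=1024, codebook_size=8192, codebook_dim=8).eval()

z = torch.randn(2, 1024, 50)                            # (B, D, T) continuous latents
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)   # losses are zeros in eval mode
print(z_q.shape, indices.shape)                         # (2, 1024, 50), (2, 50)
```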
modules/audio_tokenizer/rep_codec.py ADDED
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ from modules.audio_tokenizer.quantize import ResidualVQ
6
+ from modules.audio_tokenizer.vocos import VocosBackbone
7
+ from modules.audio_tokenizer.transformer import TransformerEncoder
8
+
9
+ def init_weights(m):
10
+ if isinstance(m, nn.Conv1d):
11
+ nn.init.trunc_normal_(m.weight, std=0.02)
12
+ nn.init.constant_(m.bias, 0)
13
+ if isinstance(m, nn.Linear):
14
+ nn.init.trunc_normal_(m.weight, std=0.02)
15
+ nn.init.constant_(m.bias, 0)
16
+
17
+ class RepCodec(nn.Module):
18
+ def __init__(
19
+ self,
20
+ codebook_size=8192,
21
+ hidden_size=1024,
22
+ codebook_dim=8,
23
+ vocos_dim=384,
24
+ vocos_intermediate_dim=2048,
25
+ vocos_num_layers=12,
26
+ num_quantizers=1,
27
+ use_timbre_encoder=False,
28
+ cfg=None,
29
+ ):
30
+ super().__init__()
31
+ codebook_size = (
32
+ cfg.codebook_size
33
+ if cfg is not None and hasattr(cfg, "codebook_size")
34
+ else codebook_size
35
+ )
36
+ codebook_dim = (
37
+ cfg.codebook_dim
38
+ if cfg is not None and hasattr(cfg, "codebook_dim")
39
+ else codebook_dim
40
+ )
41
+ hidden_size = (
42
+ cfg.hidden_size
43
+ if cfg is not None and hasattr(cfg, "hidden_size")
44
+ else hidden_size
45
+ )
46
+ vocos_dim = (
47
+ cfg.vocos_dim
48
+ if cfg is not None and hasattr(cfg, "vocos_dim")
49
+ else vocos_dim
50
+ )
51
+ vocos_intermediate_dim = (
52
+ cfg.vocos_intermediate_dim
53
+ if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
54
+ else vocos_intermediate_dim
55
+ )
56
+ vocos_num_layers = (
57
+ cfg.vocos_num_layers
58
+ if cfg is not None and hasattr(cfg, "vocos_num_layers")
59
+ else vocos_num_layers
60
+ )
61
+ num_quantizers = (
62
+ cfg.num_quantizers
63
+ if cfg is not None and hasattr(cfg, "num_quantizers")
64
+ else num_quantizers
65
+ )
66
+ use_timbre_encoder = (
67
+ cfg.use_timbre_encoder
68
+ if cfg is not None and hasattr(cfg, "use_timbre_encoder")
69
+ else use_timbre_encoder
70
+ )
71
+
72
+ self.codebook_size = codebook_size
73
+ self.codebook_dim = codebook_dim
74
+ self.hidden_size = hidden_size
75
+ self.vocos_dim = vocos_dim
76
+ self.vocos_intermediate_dim = vocos_intermediate_dim
77
+ self.vocos_num_layers = vocos_num_layers
78
+ self.num_quantizers = num_quantizers
79
+ self.use_timbre_encoder = use_timbre_encoder
80
+
81
+ self.encoder = nn.Sequential(
82
+ VocosBackbone(
83
+ input_channels=self.hidden_size,
84
+ dim=384,
85
+ intermediate_dim=2048,
86
+ num_layers=12,
87
+ adanorm_num_embeddings=None
88
+ ),
89
+ nn.Linear(384, self.hidden_size)
90
+ )
91
+ self.decoder = nn.Sequential(
92
+ VocosBackbone(
93
+ input_channels=self.hidden_size,
94
+ dim=384,
95
+ intermediate_dim=2048,
96
+ num_layers=12,
97
+ adanorm_num_embeddings=None
98
+ ),
99
+ nn.Linear(384, self.hidden_size)
100
+ )
101
+
102
+ self.quantizer = ResidualVQ(
103
+ input_dim=hidden_size,
104
+ num_quantizers=num_quantizers,
105
+ codebook_size=codebook_size,
106
+ codebook_dim=codebook_dim,
107
+ quantizer_type="fvq",
108
+ quantizer_dropout=0.0,
109
+ commitment=0.15,
110
+ codebook_loss_weight=1.0,
111
+ use_l2_normlize=True,
112
+ )
113
+
114
+ if self.use_timbre_encoder:  # TODO: expose the encoder hidden size (256) as a hyperparameter
115
+ self.timbre_in = nn.Linear(hidden_size, 256)
116
+ self.timbre_encoder = TransformerEncoder(
117
+ enc_emb_tokens=None,
118
+ encoder_layer=4,
119
+ encoder_hidden=256,
120
+ encoder_head=4,
121
+ conv_filter_size=1024,
122
+ conv_kernel_size=5,
123
+ encoder_dropout=0.1,
124
+ use_pe=False,
125
+ cfg=None,
126
+ )
127
+ self.timbre_out = nn.Linear(256, hidden_size)
128
+ self.timbre_linear = nn.Linear(hidden_size, hidden_size * 2)
129
+ self.timbre_linear.bias.data[:hidden_size] = 1
130
+ self.timbre_linear.bias.data[hidden_size:] = 0
131
+ self.timbre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
132
+ self.enc_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
133
+
134
+ self.reset_parameters()
135
+
136
+ def forward(self, x):
137
+
138
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
139
+
140
+ if self.use_timbre_encoder:
141
+ x_timbre = x
142
+ x = x.transpose(1, 2)
143
+ x = self.enc_ln(x)
144
+ x = x.transpose(1, 2)
145
+
146
+ (
147
+ quantized_out,
148
+ all_indices,
149
+ all_commit_losses,
150
+ all_codebook_losses,
151
+ _,
152
+ ) = self.quantizer(x)
153
+
154
+ if self.use_timbre_encoder:
155
+ x_timbre = x_timbre.transpose(1, 2)
156
+ x_timbre = self.timbre_in(x_timbre)
157
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
158
+ x_timbre = self.timbre_out(x_timbre)
159
+ x_timbre = x_timbre.transpose(1, 2)
160
+ spk_embs = torch.mean(x_timbre, dim=2)
161
+
162
+ style = self.timbre_linear(spk_embs).unsqueeze(2) # (B, 2d, 1)
163
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
164
+ quantized_out = quantized_out.transpose(1, 2)
165
+ quantized_out = self.timbre_norm(quantized_out)
166
+ quantized_out = quantized_out.transpose(1, 2)
167
+ quantized_out = quantized_out * gamma + beta
168
+
169
+
170
+ x_rec = self.decoder(quantized_out)
171
+
172
+ codebook_loss = (all_codebook_losses + all_commit_losses).mean()
173
+ all_indices = all_indices
174
+
175
+ return x_rec, codebook_loss, all_indices
176
+
177
+ def quantize(self, x):
178
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
179
+
180
+ if self.use_timbre_encoder:
181
+ x = x.transpose(1, 2)
182
+ x = self.enc_ln(x)
183
+ x = x.transpose(1, 2)
184
+
185
+ (
186
+ quantized_out,
187
+ all_indices,
188
+ all_commit_losses,
189
+ all_codebook_losses,
190
+ _,
191
+ ) = self.quantizer(x)
192
+ if all_indices.shape[0] == 1:
193
+ return all_indices.squeeze(0), quantized_out.transpose(1, 2)
194
+ return all_indices, quantized_out.transpose(1, 2)
195
+
196
+ def reset_parameters(self):
197
+ self.apply(init_weights)
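A minimal sketch of calling `RepCodec.quantize` on frame-level features with the default settings above (`num_quantizers=1`, no timbre encoder) and randomly initialized weights; the real checkpoint and feature extraction live in `modules/audio_tokenizer/audio_tokenizer.py`. The output shapes assume `ResidualVQ` stacks indices with a leading quantizer axis, as implied by the `all_indices.shape[0] == 1` check.

```python
import torch
from modules.audio_tokenizer.rep_codec import RepCodec

codec = RepCodec(codebook_size=8192, hidden_size=1024, codebook_dim=8, num_quantizers=1).eval()

feats = torch.randn(1, 150, 1024)          # (B, T, hidden_size) frame-level features
with torch.no_grad():
    indices, quantized = codec.quantize(feats)
# With num_quantizers=1 the leading quantizer axis is squeezed away,
# giving (B, T) token ids and (B, T, hidden_size) quantized features.
print(indices.shape, quantized.shape)
```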
modules/audio_tokenizer/transformer.py ADDED
@@ -0,0 +1,234 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ import math
6
+
7
+
8
+ class StyleAdaptiveLayerNorm(nn.Module):
9
+ def __init__(self, normalized_shape, eps=1e-5):
10
+ super().__init__()
11
+ self.in_dim = normalized_shape
12
+ self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
13
+ self.style = nn.Linear(self.in_dim, self.in_dim * 2)
14
+ self.style.bias.data[: self.in_dim] = 1
15
+ self.style.bias.data[self.in_dim :] = 0
16
+
17
+ def forward(self, x, condition):
18
+ # x: (B, T, d); condition: (B, T, d)
19
+
20
+ style = self.style(torch.mean(condition, dim=1, keepdim=True))
21
+
22
+ gamma, beta = style.chunk(2, -1)
23
+
24
+ out = self.norm(x)
25
+
26
+ out = gamma * out + beta
27
+ return out
28
+
29
+
30
+ class PositionalEncoding(nn.Module):
31
+ def __init__(self, d_model, dropout, max_len=5000):
32
+ super().__init__()
33
+
34
+ self.dropout = dropout
35
+ position = torch.arange(max_len).unsqueeze(1)
36
+ div_term = torch.exp(
37
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
38
+ )
39
+ pe = torch.zeros(max_len, 1, d_model)
40
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
41
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
42
+ self.register_buffer("pe", pe)
43
+
44
+ def forward(self, x):
45
+ x = x + self.pe[: x.size(0)]
46
+ return F.dropout(x, self.dropout, training=self.training)
47
+
48
+
49
+ class TransformerFFNLayer(nn.Module):
50
+ def __init__(
51
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
52
+ ):
53
+ super().__init__()
54
+
55
+ self.encoder_hidden = encoder_hidden
56
+ self.conv_filter_size = conv_filter_size
57
+ self.conv_kernel_size = conv_kernel_size
58
+ self.encoder_dropout = encoder_dropout
59
+
60
+ self.ffn_1 = nn.Conv1d(
61
+ self.encoder_hidden,
62
+ self.conv_filter_size,
63
+ self.conv_kernel_size,
64
+ padding=self.conv_kernel_size // 2,
65
+ )
66
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
67
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
68
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
69
+
70
+ def forward(self, x):
71
+ # x: (B, T, d)
72
+ x = self.ffn_1(x.permute(0, 2, 1)).permute(
73
+ 0, 2, 1
74
+ ) # (B, T, d) -> (B, d, T) -> (B, T, d)
75
+ x = F.relu(x)
76
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
77
+ x = self.ffn_2(x)
78
+ return x
79
+
80
+
81
+ class TransformerEncoderLayer(nn.Module):
82
+ def __init__(
83
+ self,
84
+ encoder_hidden,
85
+ encoder_head,
86
+ conv_filter_size,
87
+ conv_kernel_size,
88
+ encoder_dropout,
89
+ use_cln,
90
+ ):
91
+ super().__init__()
92
+ self.encoder_hidden = encoder_hidden
93
+ self.encoder_head = encoder_head
94
+ self.conv_filter_size = conv_filter_size
95
+ self.conv_kernel_size = conv_kernel_size
96
+ self.encoder_dropout = encoder_dropout
97
+ self.use_cln = use_cln
98
+
99
+ if not self.use_cln:
100
+ self.ln_1 = nn.LayerNorm(self.encoder_hidden)
101
+ self.ln_2 = nn.LayerNorm(self.encoder_hidden)
102
+ else:
103
+ self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
104
+ self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
105
+
106
+ self.self_attn = nn.MultiheadAttention(
107
+ self.encoder_hidden, self.encoder_head, batch_first=True
108
+ )
109
+
110
+ self.ffn = TransformerFFNLayer(
111
+ self.encoder_hidden,
112
+ self.conv_filter_size,
113
+ self.conv_kernel_size,
114
+ self.encoder_dropout,
115
+ )
116
+
117
+ def forward(self, x, key_padding_mask, condition=None):
118
+ # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
119
+
120
+ # self attention
121
+ residual = x
122
+ if self.use_cln:
123
+ x = self.ln_1(x, condition)
124
+ else:
125
+ x = self.ln_1(x)
126
+
127
+ if key_padding_mask is not None:
128
+ key_padding_mask_input = ~(key_padding_mask.bool())
129
+ else:
130
+ key_padding_mask_input = None
131
+ x, _ = self.self_attn(
132
+ query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
133
+ )
134
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
135
+ x = residual + x
136
+
137
+ # ffn
138
+ residual = x
139
+ if self.use_cln:
140
+ x = self.ln_2(x, condition)
141
+ else:
142
+ x = self.ln_2(x)
143
+ x = self.ffn(x)
144
+ x = residual + x
145
+
146
+ return x
147
+
148
+
149
+ class TransformerEncoder(nn.Module):
150
+ def __init__(
151
+ self,
152
+ enc_emb_tokens=None,
153
+ encoder_layer=4,
154
+ encoder_hidden=256,
155
+ encoder_head=4,
156
+ conv_filter_size=1024,
157
+ conv_kernel_size=5,
158
+ encoder_dropout=0.1,
159
+ use_cln=False,
160
+ use_pe=True,
161
+ cfg=None,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.encoder_layer = (
166
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
167
+ )
168
+ self.encoder_hidden = (
169
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
170
+ )
171
+ self.encoder_head = (
172
+ encoder_head if encoder_head is not None else cfg.encoder_head
173
+ )
174
+ self.conv_filter_size = (
175
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
176
+ )
177
+ self.conv_kernel_size = (
178
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
179
+ )
180
+ self.encoder_dropout = (
181
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
182
+ )
183
+ self.use_pe = use_pe if use_pe is not None else cfg.use_pe
184
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
185
+
186
+ if enc_emb_tokens is not None:
187
+ self.use_enc_emb = True
188
+ self.enc_emb_tokens = enc_emb_tokens
189
+ else:
190
+ self.use_enc_emb = False
191
+
192
+ if self.use_pe:
193
+ self.position_emb = PositionalEncoding(
194
+ self.encoder_hidden, self.encoder_dropout
195
+ )
196
+
197
+ self.layers = nn.ModuleList([])
198
+ self.layers.extend(
199
+ [
200
+ TransformerEncoderLayer(
201
+ self.encoder_hidden,
202
+ self.encoder_head,
203
+ self.conv_filter_size,
204
+ self.conv_kernel_size,
205
+ self.encoder_dropout,
206
+ self.use_cln,
207
+ )
208
+ for i in range(self.encoder_layer)
209
+ ]
210
+ )
211
+
212
+ if self.use_cln:
213
+ self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
214
+ else:
215
+ self.last_ln = nn.LayerNorm(self.encoder_hidden)
216
+
217
+ def forward(self, x, key_padding_mask, condition=None):
218
+ if len(x.shape) == 2 and self.use_enc_emb:
219
+ x = self.enc_emb_tokens(x)
220
+ if self.use_pe:
221
+ x = self.position_emb(x)
222
+ else:
223
+ if self.use_pe:
224
+ x = self.position_emb(x) # (B, T, d)
225
+
226
+ for layer in self.layers:
227
+ x = layer(x, key_padding_mask, condition)
228
+
229
+ if self.use_cln:
230
+ x = self.last_ln(x, condition)
231
+ else:
232
+ x = self.last_ln(x)
233
+
234
+ return x
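A small usage sketch of the encoder above, using the same hyperparameters RepCodec passes to its timbre encoder (`use_pe=False`); the mask follows the convention in the comments, where 1 marks a valid frame and 0 marks padding.

```python
import torch
from modules.audio_tokenizer.transformer import TransformerEncoder

enc = TransformerEncoder(encoder_layer=4, encoder_hidden=256, encoder_head=4,
                         conv_filter_size=1024, conv_kernel_size=5,
                         encoder_dropout=0.1, use_pe=False)

x = torch.randn(2, 100, 256)       # (B, T, encoder_hidden)
mask = torch.ones(2, 100)          # 1 = valid frame, 0 = padded frame
out = enc(x, mask)                 # (B, T, encoder_hidden)
print(out.shape)
```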
modules/audio_tokenizer/vocos.py ADDED
@@ -0,0 +1,845 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn, view_as_real, view_as_complex
7
+ from torch import nn
8
+ from torch.nn.utils import weight_norm, remove_weight_norm
9
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
10
+
11
+
12
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
13
+ """
14
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
15
+
16
+ Args:
17
+ x (Tensor): Input tensor.
18
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
19
+
20
+ Returns:
21
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
22
+ """
23
+ return torch.log(torch.clip(x, min=clip_val))
24
+
25
+
26
+ def symlog(x: torch.Tensor) -> torch.Tensor:
27
+ return torch.sign(x) * torch.log1p(x.abs())
28
+
29
+
30
+ def symexp(x: torch.Tensor) -> torch.Tensor:
31
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
32
+
33
+
34
+ class STFT(nn.Module):
35
+ def __init__(
36
+ self,
37
+ n_fft: int,
38
+ hop_length: int,
39
+ win_length: int,
40
+ center=True,
41
+ ):
42
+ super().__init__()
43
+ self.center = center
44
+ self.n_fft = n_fft
45
+ self.hop_length = hop_length
46
+ self.win_length = win_length
47
+ window = torch.hann_window(win_length)
48
+ self.register_buffer("window", window)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ # x: (B, T * hop_length)
52
+
53
+ if not self.center:
54
+ pad = self.win_length - self.hop_length
55
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
56
+
57
+ stft_spec = torch.stft(
58
+ x,
59
+ self.n_fft,
60
+ hop_length=self.hop_length,
61
+ win_length=self.win_length,
62
+ window=self.window,
63
+ center=self.center,
64
+ return_complex=False,
65
+ ) # (B, n_fft // 2 + 1, T, 2)
66
+
67
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T)
68
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T)
69
+
70
+ log_mag = torch.log(
71
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
72
+ ) # (B, n_fft // 2 + 1, T)
73
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
74
+
75
+ return log_mag, phase
76
+
77
+
78
+ class ISTFT(nn.Module):
79
+ """
80
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
81
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
82
+ See issue: https://github.com/pytorch/pytorch/issues/62323
83
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
84
+ The NOLA constraint is met as we trim padded samples anyway.
85
+
86
+ Args:
87
+ n_fft (int): Size of Fourier transform.
88
+ hop_length (int): The distance between neighboring sliding window frames.
89
+ win_length (int): The size of window frame and STFT filter.
90
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
91
+ """
92
+
93
+ def __init__(
94
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
95
+ ):
96
+ super().__init__()
97
+ if padding not in ["center", "same"]:
98
+ raise ValueError("Padding must be 'center' or 'same'.")
99
+ self.padding = padding
100
+ self.n_fft = n_fft
101
+ self.hop_length = hop_length
102
+ self.win_length = win_length
103
+ window = torch.hann_window(win_length)
104
+ self.register_buffer("window", window)
105
+
106
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
109
+
110
+ Args:
111
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
112
+ N is the number of frequency bins, and T is the number of time frames.
113
+
114
+ Returns:
115
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
116
+ """
117
+ if self.padding == "center":
118
+ # Fallback to pytorch native implementation
119
+ return torch.istft(
120
+ spec,
121
+ self.n_fft,
122
+ self.hop_length,
123
+ self.win_length,
124
+ self.window,
125
+ center=True,
126
+ )
127
+ elif self.padding == "same":
128
+ pad = (self.win_length - self.hop_length) // 2
129
+ else:
130
+ raise ValueError("Padding must be 'center' or 'same'.")
131
+
132
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
133
+ B, N, T = spec.shape
134
+
135
+ # Inverse FFT
136
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
137
+ ifft = ifft * self.window[None, :, None]
138
+
139
+ # Overlap and Add
140
+ output_size = (T - 1) * self.hop_length + self.win_length
141
+ y = torch.nn.functional.fold(
142
+ ifft,
143
+ output_size=(1, output_size),
144
+ kernel_size=(1, self.win_length),
145
+ stride=(1, self.hop_length),
146
+ )[:, 0, 0, pad:-pad]
147
+
148
+ # Window envelope
149
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
150
+ window_envelope = torch.nn.functional.fold(
151
+ window_sq,
152
+ output_size=(1, output_size),
153
+ kernel_size=(1, self.win_length),
154
+ stride=(1, self.hop_length),
155
+ ).squeeze()[pad:-pad]
156
+
157
+ # Normalize
158
+ assert (window_envelope > 1e-11).all()
159
+ y = y / window_envelope
160
+
161
+ return y
162
+
163
+
164
+ class MDCT(nn.Module):
165
+ """
166
+ Modified Discrete Cosine Transform (MDCT) module.
167
+
168
+ Args:
169
+ frame_len (int): Length of the MDCT frame.
170
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
171
+ """
172
+
173
+ def __init__(self, frame_len: int, padding: str = "same"):
174
+ super().__init__()
175
+ if padding not in ["center", "same"]:
176
+ raise ValueError("Padding must be 'center' or 'same'.")
177
+ self.padding = padding
178
+ self.frame_len = frame_len
179
+ N = frame_len // 2
180
+ n0 = (N + 1) / 2
181
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
182
+ self.register_buffer("window", window)
183
+
184
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
185
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
186
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
187
+ # https://github.com/pytorch/pytorch/issues/71613
188
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
189
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
190
+
191
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
192
+ """
193
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
194
+
195
+ Args:
196
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
197
+ and T is the length of the audio.
198
+
199
+ Returns:
200
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
201
+ and N is the number of frequency bins.
202
+ """
203
+ if self.padding == "center":
204
+ audio = torch.nn.functional.pad(
205
+ audio, (self.frame_len // 2, self.frame_len // 2)
206
+ )
207
+ elif self.padding == "same":
208
+ # hop_length is 1/2 frame_len
209
+ audio = torch.nn.functional.pad(
210
+ audio, (self.frame_len // 4, self.frame_len // 4)
211
+ )
212
+ else:
213
+ raise ValueError("Padding must be 'center' or 'same'.")
214
+
215
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
216
+ N = self.frame_len // 2
217
+ x = x * self.window.expand(x.shape)
218
+ X = torch.fft.fft(
219
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
220
+ )[..., :N]
221
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
222
+ return torch.real(res) * np.sqrt(2)
223
+
224
+
225
+ class IMDCT(nn.Module):
226
+ """
227
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
228
+
229
+ Args:
230
+ frame_len (int): Length of the MDCT frame.
231
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
232
+ """
233
+
234
+ def __init__(self, frame_len: int, padding: str = "same"):
235
+ super().__init__()
236
+ if padding not in ["center", "same"]:
237
+ raise ValueError("Padding must be 'center' or 'same'.")
238
+ self.padding = padding
239
+ self.frame_len = frame_len
240
+ N = frame_len // 2
241
+ n0 = (N + 1) / 2
242
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
243
+ self.register_buffer("window", window)
244
+
245
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
246
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
247
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
248
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
249
+
250
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
251
+ """
252
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
253
+
254
+ Args:
255
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
256
+ L is the number of frames, and N is the number of frequency bins.
257
+
258
+ Returns:
259
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
260
+ """
261
+ B, L, N = X.shape
262
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
263
+ Y[..., :N] = X
264
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
265
+ y = torch.fft.ifft(
266
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
267
+ )
268
+ y = (
269
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
270
+ * np.sqrt(N)
271
+ * np.sqrt(2)
272
+ )
273
+ result = y * self.window.expand(y.shape)
274
+ output_size = (1, (L + 1) * N)
275
+ audio = torch.nn.functional.fold(
276
+ result.transpose(1, 2),
277
+ output_size=output_size,
278
+ kernel_size=(1, self.frame_len),
279
+ stride=(1, self.frame_len // 2),
280
+ )[:, 0, 0, :]
281
+
282
+ if self.padding == "center":
283
+ pad = self.frame_len // 2
284
+ elif self.padding == "same":
285
+ pad = self.frame_len // 4
286
+ else:
287
+ raise ValueError("Padding must be 'center' or 'same'.")
288
+
289
+ audio = audio[:, pad:-pad]
290
+ return audio
291
+
292
+
293
+ class FourierHead(nn.Module):
294
+ """Base class for inverse fourier modules."""
295
+
296
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
297
+ """
298
+ Args:
299
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
300
+ L is the sequence length, and H denotes the model dimension.
301
+
302
+ Returns:
303
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
304
+ """
305
+ raise NotImplementedError("Subclasses must implement the forward method.")
306
+
307
+
308
+ class ISTFTHead(FourierHead):
309
+ """
310
+ ISTFT Head module for predicting STFT complex coefficients.
311
+
312
+ Args:
313
+ dim (int): Hidden dimension of the model.
314
+ n_fft (int): Size of Fourier transform.
315
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
316
+ the resolution of the input features.
317
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
318
+ """
319
+
320
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
321
+ super().__init__()
322
+ out_dim = n_fft + 2
323
+ self.out = torch.nn.Linear(dim, out_dim)
324
+ self.istft = ISTFT(
325
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
326
+ )
327
+
328
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
329
+ """
330
+ Forward pass of the ISTFTHead module.
331
+
332
+ Args:
333
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
334
+ L is the sequence length, and H denotes the model dimension.
335
+
336
+ Returns:
337
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
338
+ """
339
+ x = self.out(x).transpose(1, 2)
340
+ mag, p = x.chunk(2, dim=1)
341
+ mag = torch.exp(mag)
342
+ mag = torch.clip(
343
+ mag, max=1e2
344
+ ) # safeguard to prevent excessively large magnitudes
345
+ # wrapping happens here. These two lines produce real and imaginary value
346
+ x = torch.cos(p)
347
+ y = torch.sin(p)
348
+ # recalculating phase here does not produce anything new
349
+ # only costs time
350
+ # phase = torch.atan2(y, x)
351
+ # S = mag * torch.exp(phase * 1j)
352
+ # better directly produce the complex value
353
+ S = mag * (x + 1j * y)
354
+ audio = self.istft(S)
355
+ return audio
356
+
357
+
358
+ class IMDCTSymExpHead(FourierHead):
359
+ """
360
+ IMDCT Head module for predicting MDCT coefficients with a symmetric exponential function
361
+
362
+ Args:
363
+ dim (int): Hidden dimension of the model.
364
+ mdct_frame_len (int): Length of the MDCT frame.
365
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
366
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
367
+ based on perceptual scaling. Defaults to None.
368
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
369
+ """
370
+
371
+ def __init__(
372
+ self,
373
+ dim: int,
374
+ mdct_frame_len: int,
375
+ padding: str = "same",
376
+ sample_rate: Optional[int] = None,
377
+ clip_audio: bool = False,
378
+ ):
379
+ super().__init__()
380
+ out_dim = mdct_frame_len // 2
381
+ self.out = nn.Linear(dim, out_dim)
382
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
383
+ self.clip_audio = clip_audio
384
+
385
+ if sample_rate is not None:
386
+ # optionally init the last layer following mel-scale
387
+ m_max = _hz_to_mel(sample_rate // 2)
388
+ m_pts = torch.linspace(0, m_max, out_dim)
389
+ f_pts = _mel_to_hz(m_pts)
390
+ scale = 1 - (f_pts / f_pts.max())
391
+
392
+ with torch.no_grad():
393
+ self.out.weight.mul_(scale.view(-1, 1))
394
+
395
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
396
+ """
397
+ Forward pass of the IMDCTSymExpHead module.
398
+
399
+ Args:
400
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
401
+ L is the sequence length, and H denotes the model dimension.
402
+
403
+ Returns:
404
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
405
+ """
406
+ x = self.out(x)
407
+ x = symexp(x)
408
+ x = torch.clip(
409
+ x, min=-1e2, max=1e2
410
+ ) # safeguard to prevent excessively large magnitudes
411
+ audio = self.imdct(x)
412
+ if self.clip_audio:
413
+ audio = torch.clip(audio, min=-1.0, max=1.0)
414
+
415
+ return audio
416
+
417
+
418
+ class IMDCTCosHead(FourierHead):
419
+ """
420
+ IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)
421
+
422
+ Args:
423
+ dim (int): Hidden dimension of the model.
424
+ mdct_frame_len (int): Length of the MDCT frame.
425
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
426
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
427
+ """
428
+
429
+ def __init__(
430
+ self,
431
+ dim: int,
432
+ mdct_frame_len: int,
433
+ padding: str = "same",
434
+ clip_audio: bool = False,
435
+ ):
436
+ super().__init__()
437
+ self.clip_audio = clip_audio
438
+ self.out = nn.Linear(dim, mdct_frame_len)
439
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
440
+
441
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
442
+ """
443
+ Forward pass of the IMDCTCosHead module.
444
+
445
+ Args:
446
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
447
+ L is the sequence length, and H denotes the model dimension.
448
+
449
+ Returns:
450
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
451
+ """
452
+ x = self.out(x)
453
+ m, p = x.chunk(2, dim=2)
454
+ m = torch.exp(m).clip(
455
+ max=1e2
456
+ ) # safeguard to prevent excessively large magnitudes
457
+ audio = self.imdct(m * torch.cos(p))
458
+ if self.clip_audio:
459
+ audio = torch.clip(audio, min=-1.0, max=1.0)
460
+ return audio
461
+
462
+
463
+ class ConvNeXtBlock(nn.Module):
464
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
465
+
466
+ Args:
467
+ dim (int): Number of input channels.
468
+ intermediate_dim (int): Dimensionality of the intermediate layer.
469
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
470
+ Defaults to None.
471
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
472
+ None means non-conditional LayerNorm. Defaults to None.
473
+ """
474
+
475
+ def __init__(
476
+ self,
477
+ dim: int,
478
+ intermediate_dim: int,
479
+ layer_scale_init_value: float,
480
+ adanorm_num_embeddings: Optional[int] = None,
481
+ ):
482
+ super().__init__()
483
+ self.dwconv = nn.Conv1d(
484
+ dim, dim, kernel_size=7, padding=3, groups=dim
485
+ ) # depthwise conv
486
+ self.adanorm = adanorm_num_embeddings is not None
487
+ if adanorm_num_embeddings:
488
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
489
+ else:
490
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
491
+ self.pwconv1 = nn.Linear(
492
+ dim, intermediate_dim
493
+ ) # pointwise/1x1 convs, implemented with linear layers
494
+ self.act = nn.GELU()
495
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
496
+ self.gamma = (
497
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
498
+ if layer_scale_init_value > 0
499
+ else None
500
+ )
501
+
502
+ def forward(
503
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
504
+ ) -> torch.Tensor:
505
+ residual = x
506
+ x = self.dwconv(x)
507
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
508
+ if self.adanorm:
509
+ assert cond_embedding_id is not None
510
+ x = self.norm(x, cond_embedding_id)
511
+ else:
512
+ x = self.norm(x)
513
+ x = self.pwconv1(x)
514
+ x = self.act(x)
515
+ x = self.pwconv2(x)
516
+ if self.gamma is not None:
517
+ x = self.gamma * x
518
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
519
+
520
+ x = residual + x
521
+ return x
522
+
523
+
524
+ class AdaLayerNorm(nn.Module):
525
+ """
526
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
527
+
528
+ Args:
529
+ num_embeddings (int): Number of embeddings.
530
+ embedding_dim (int): Dimension of the embeddings.
531
+ """
532
+
533
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
534
+ super().__init__()
535
+ self.eps = eps
536
+ self.dim = embedding_dim
537
+ self.scale = nn.Embedding(
538
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
539
+ )
540
+ self.shift = nn.Embedding(
541
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
542
+ )
543
+ torch.nn.init.ones_(self.scale.weight)
544
+ torch.nn.init.zeros_(self.shift.weight)
545
+
546
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
547
+ scale = self.scale(cond_embedding_id)
548
+ shift = self.shift(cond_embedding_id)
549
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
550
+ x = x * scale + shift
551
+ return x
552
+
553
+
554
+ class ResBlock1(nn.Module):
555
+ """
556
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
557
+ but without upsampling layers.
558
+
559
+ Args:
560
+ dim (int): Number of input channels.
561
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
562
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
563
+ Defaults to (1, 3, 5).
564
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
565
+ Defaults to 0.1.
566
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
567
+ Defaults to None.
568
+ """
569
+
570
+ def __init__(
571
+ self,
572
+ dim: int,
573
+ kernel_size: int = 3,
574
+ dilation: Tuple[int, int, int] = (1, 3, 5),
575
+ lrelu_slope: float = 0.1,
576
+ layer_scale_init_value: Optional[float] = None,
577
+ ):
578
+ super().__init__()
579
+ self.lrelu_slope = lrelu_slope
580
+ self.convs1 = nn.ModuleList(
581
+ [
582
+ weight_norm(
583
+ nn.Conv1d(
584
+ dim,
585
+ dim,
586
+ kernel_size,
587
+ 1,
588
+ dilation=dilation[0],
589
+ padding=self.get_padding(kernel_size, dilation[0]),
590
+ )
591
+ ),
592
+ weight_norm(
593
+ nn.Conv1d(
594
+ dim,
595
+ dim,
596
+ kernel_size,
597
+ 1,
598
+ dilation=dilation[1],
599
+ padding=self.get_padding(kernel_size, dilation[1]),
600
+ )
601
+ ),
602
+ weight_norm(
603
+ nn.Conv1d(
604
+ dim,
605
+ dim,
606
+ kernel_size,
607
+ 1,
608
+ dilation=dilation[2],
609
+ padding=self.get_padding(kernel_size, dilation[2]),
610
+ )
611
+ ),
612
+ ]
613
+ )
614
+
615
+ self.convs2 = nn.ModuleList(
616
+ [
617
+ weight_norm(
618
+ nn.Conv1d(
619
+ dim,
620
+ dim,
621
+ kernel_size,
622
+ 1,
623
+ dilation=1,
624
+ padding=self.get_padding(kernel_size, 1),
625
+ )
626
+ ),
627
+ weight_norm(
628
+ nn.Conv1d(
629
+ dim,
630
+ dim,
631
+ kernel_size,
632
+ 1,
633
+ dilation=1,
634
+ padding=self.get_padding(kernel_size, 1),
635
+ )
636
+ ),
637
+ weight_norm(
638
+ nn.Conv1d(
639
+ dim,
640
+ dim,
641
+ kernel_size,
642
+ 1,
643
+ dilation=1,
644
+ padding=self.get_padding(kernel_size, 1),
645
+ )
646
+ ),
647
+ ]
648
+ )
649
+
650
+ self.gamma = nn.ParameterList(
651
+ [
652
+ (
653
+ nn.Parameter(
654
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
655
+ )
656
+ if layer_scale_init_value is not None
657
+ else None
658
+ ),
659
+ (
660
+ nn.Parameter(
661
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
662
+ )
663
+ if layer_scale_init_value is not None
664
+ else None
665
+ ),
666
+ (
667
+ nn.Parameter(
668
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
669
+ )
670
+ if layer_scale_init_value is not None
671
+ else None
672
+ ),
673
+ ]
674
+ )
675
+
676
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
677
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
678
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
679
+ xt = c1(xt)
680
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
681
+ xt = c2(xt)
682
+ if gamma is not None:
683
+ xt = gamma * xt
684
+ x = xt + x
685
+ return x
686
+
687
+ def remove_weight_norm(self):
688
+ for l in self.convs1:
689
+ remove_weight_norm(l)
690
+ for l in self.convs2:
691
+ remove_weight_norm(l)
692
+
693
+ @staticmethod
694
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
695
+ return int((kernel_size * dilation - dilation) / 2)
696
+
697
+
698
+ class Backbone(nn.Module):
699
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
700
+
701
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
702
+ """
703
+ Args:
704
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
705
+ C denotes output features, and L is the sequence length.
706
+
707
+ Returns:
708
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
709
+ and H denotes the model dimension.
710
+ """
711
+ raise NotImplementedError("Subclasses must implement the forward method.")
712
+
713
+
714
+ class VocosBackbone(Backbone):
715
+ """
716
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
717
+
718
+ Args:
719
+ input_channels (int): Number of input features channels.
720
+ dim (int): Hidden dimension of the model.
721
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
722
+ num_layers (int): Number of ConvNeXtBlock layers.
723
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
724
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
725
+ None means non-conditional model. Defaults to None.
726
+ """
727
+
728
+ def __init__(
729
+ self,
730
+ input_channels: int,
731
+ dim: int,
732
+ intermediate_dim: int,
733
+ num_layers: int,
734
+ layer_scale_init_value: Optional[float] = None,
735
+ adanorm_num_embeddings: Optional[int] = None,
736
+ ):
737
+ super().__init__()
738
+ self.input_channels = input_channels
739
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
740
+ self.adanorm = adanorm_num_embeddings is not None
741
+ if adanorm_num_embeddings:
742
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
743
+ else:
744
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
745
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
746
+ self.convnext = nn.ModuleList(
747
+ [
748
+ ConvNeXtBlock(
749
+ dim=dim,
750
+ intermediate_dim=intermediate_dim,
751
+ layer_scale_init_value=layer_scale_init_value,
752
+ adanorm_num_embeddings=adanorm_num_embeddings,
753
+ )
754
+ for _ in range(num_layers)
755
+ ]
756
+ )
757
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
758
+ self.apply(self._init_weights)
759
+
760
+ def _init_weights(self, m):
761
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
762
+ nn.init.trunc_normal_(m.weight, std=0.02)
763
+ nn.init.constant_(m.bias, 0)
764
+
765
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
766
+ bandwidth_id = kwargs.get("bandwidth_id", None)
767
+ x = self.embed(x)
768
+ if self.adanorm:
769
+ assert bandwidth_id is not None
770
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
771
+ else:
772
+ x = self.norm(x.transpose(1, 2))
773
+ x = x.transpose(1, 2)
774
+ for conv_block in self.convnext:
775
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
776
+ x = self.final_layer_norm(x.transpose(1, 2))
777
+ return x
778
+
779
+
780
+ class VocosResNetBackbone(Backbone):
781
+ """
782
+ Vocos backbone module built with ResBlocks.
783
+
784
+ Args:
785
+ input_channels (int): Number of input features channels.
786
+ dim (int): Hidden dimension of the model.
787
+ num_blocks (int): Number of ResBlock1 blocks.
788
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
789
+ """
790
+
791
+ def __init__(
792
+ self,
793
+ input_channels,
794
+ dim,
795
+ num_blocks,
796
+ layer_scale_init_value=None,
797
+ ):
798
+ super().__init__()
799
+ self.input_channels = input_channels
800
+ self.embed = weight_norm(
801
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
802
+ )
803
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
804
+ self.resnet = nn.Sequential(
805
+ *[
806
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
807
+ for _ in range(num_blocks)
808
+ ]
809
+ )
810
+
811
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
812
+ x = self.embed(x)
813
+ x = self.resnet(x)
814
+ x = x.transpose(1, 2)
815
+ return x
816
+
817
+
818
+ class Vocos(nn.Module):
819
+ def __init__(
820
+ self,
821
+ input_channels: int = 256,
822
+ dim: int = 384,
823
+ intermediate_dim: int = 1152,
824
+ num_layers: int = 8,
825
+ adanorm_num_embeddings: int = 4,
826
+ n_fft: int = 800,
827
+ hop_size: int = 200,
828
+ padding: str = "same",
829
+ ):
830
+ super().__init__()
831
+
832
+ self.backbone = VocosBackbone(
833
+ input_channels=input_channels,
834
+ dim=dim,
835
+ intermediate_dim=intermediate_dim,
836
+ num_layers=num_layers,
837
+ adanorm_num_embeddings=adanorm_num_embeddings,
838
+ )
839
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
840
+
841
+ def forward(self, x):
842
+ x = self.backbone(x)
843
+ x = self.head(x)
844
+
845
+ return x[:, None, :]
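A minimal sketch of the backbone as RepCodec configures it (`adanorm_num_embeddings=None`, i.e. plain LayerNorm), with the import path assumed; when AdaLayerNorm is enabled instead, the forward pass additionally expects a `bandwidth_id` keyword.

```python
import torch
from modules.audio_tokenizer.vocos import VocosBackbone  # assumed module path for this file

backbone = VocosBackbone(input_channels=1024, dim=384, intermediate_dim=2048,
                         num_layers=12, adanorm_num_embeddings=None)

x = torch.randn(1, 1024, 150)   # (B, C, T) frame-level features
y = backbone(x)                 # (B, T, dim) = (1, 150, 384)
print(y.shape)
```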
modules/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,243 @@
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+ import sentencepiece as spm
4
+ from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
5
+ from typing import Any, Union
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False):
10
+ """Encode text into sentence pieces. Only supports py3."""
11
+
12
+ if not sample:
13
+ pieces = sp_model.EncodeAsPieces(text)
14
+ else:
15
+ pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
16
+
17
+ return pieces
18
+
19
+
20
+ class AbstractTokenizer(ABC):
21
+ """Abstract class for tokenizer."""
22
+
23
+ def __init__(self, name):
24
+ self.name = name
25
+ super().__init__()
26
+
27
+ @property
28
+ @abstractmethod
29
+ def vocab_size(self):
30
+ pass
31
+
32
+ @property
33
+ @abstractmethod
34
+ def vocab(self):
35
+ """Dictionary from vocab text token to id token."""
36
+ pass
37
+
38
+ @property
39
+ @abstractmethod
40
+ def inv_vocab(self):
41
+ """Dictionary from vocab id token to text token."""
42
+ pass
43
+
44
+ @abstractmethod
45
+ def tokenize(self, text):
46
+ pass
47
+
48
+ def detokenize(self, token_ids):
49
+ raise NotImplementedError('detokenizer is not implemented for {} '
50
+ 'tokenizer'.format(self.name))
51
+
52
+ @property
53
+ def cls(self):
54
+ raise NotImplementedError('CLS is not provided for {} '
55
+ 'tokenizer'.format(self.name))
56
+
57
+ @property
58
+ def sep(self):
59
+ raise NotImplementedError('SEP is not provided for {} '
60
+ 'tokenizer'.format(self.name))
61
+
62
+ @property
63
+ def pad(self):
64
+ raise NotImplementedError('PAD is not provided for {} '
65
+ 'tokenizer'.format(self.name))
66
+
67
+ @property
68
+ def eod(self):
69
+ raise NotImplementedError('EOD is not provided for {} '
70
+ 'tokenizer'.format(self.name))
71
+
72
+ @property
73
+ def mask(self):
74
+ raise NotImplementedError('MASK is not provided for {} '
75
+ 'tokenizer'.format(self.name))
76
+
77
+
78
+ class SPieceTokenizer(AbstractTokenizer):
79
+ def __init__(self, spm_file: str):
80
+ super().__init__('Sentence Piece')
81
+ self.sp_model = spm.SentencePieceProcessor()
82
+ self.sp_model.Load(spm_file)
83
+ self.eod_id = self.get_token_id('</s>')
84
+
85
+ self.special_ids = set([
86
+ self.sp_model.pad_id(),
87
+ self.sp_model.eos_id(),
88
+ self.sp_model.bos_id(),
89
+ self.sp_model.unk_id(),
90
+ self.eod_id,
91
+ ])
92
+
93
+ # initialize index_2_bytes
94
+ self._initialize_index_2_bytes()
95
+
96
+ def encode_pieces(self, text: str, sample=False):
97
+ if not sample:
98
+ pieces = self.sp_model.EncodeAsPieces(text)
99
+ else:
100
+ pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
101
+ return pieces
102
+
103
+ def _initialize_index_2_bytes(self):
104
+ proto = sp_pb2_model.ModelProto()
105
+ proto.ParseFromString(self.sp_model.serialized_model_proto())
106
+ self.index_2_numbytes = [0] * len(proto.pieces)
107
+ for i, p in enumerate(proto.pieces):
108
+ clean_piece = p.piece.replace('▁', '')
109
+ self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))
110
+
111
+ def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
112
+ proto = sp_pb2_model.ModelProto()
113
+ proto.ParseFromString(self.sp_model.serialized_model_proto())
114
+ if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
115
+ proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
116
+ self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
117
+ print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)
118
+
119
+ def add_special_id(self, token_id):
120
+ self.special_ids.add(token_id)
121
+
122
+ @property
123
+ def has_dummy_prefix(self):
124
+ pieces = self.sp_model.EncodeAsPieces("hello")
125
+ return pieces[0].startswith('▁')
126
+
127
+ @property
128
+ def vocab_size(self):
129
+ return self.sp_model.GetPieceSize()
130
+
131
+ @property
132
+ def vocab(self):
133
+ """Dictionary from vocab text token to id token."""
134
+ return self.sp_model
135
+
136
+ def get_array_bytes(self, array):
137
+ return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)
138
+
139
+ def tokenize(self, text):
140
+ tokens = encode_pieces(self.sp_model, text)
141
+ return self.convert_tokens_to_ids(tokens)
142
+
143
+ def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:
144
+ tokens = self.encode_pieces(text)
145
+ t = self.convert_tokens_to_ids(tokens)
146
+ if bos:
147
+ t.insert(0, self.bos_id)
148
+ if eos:
149
+ t.append(self.eos_id)
150
+ return t
151
+
152
+ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
153
+ if isinstance(tokens, str):
154
+ return self.sp_model.PieceToId(tokens)
155
+ return [self.sp_model.PieceToId(token) for token in tokens]
156
+
157
+ def detokenize(self, token_ids):
158
+ if isinstance(token_ids, list):
159
+ pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
160
+ else:
161
+ pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
162
+ return pieces
163
+
164
+ def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
165
+ assert not skip_special_tokens, "skip_special_tokens is not supported"
166
+ if isinstance(token_ids, (int, np.integer)):
167
+ return self.detokenize([int(token_ids)])[0]
168
+ return ''.join(self.detokenize(token_ids))
169
+
170
+ def get_token_id(self, token):
171
+ return self.sp_model.PieceToId(token)
172
+
173
+ def inv_vocab(self):
174
+ # TODO: to be implemented
175
+ return {}
176
+
177
+ def decode_pieces(self, pieces):
178
+ return self.sp_model.DecodePieces(pieces)
179
+
180
+ @property
181
+ def eod(self):
182
+ return self.eod_id
183
+
184
+ @property
185
+ def pad_id(self):
186
+ return self.sp_model.pad_id()
187
+
188
+ @property
189
+ def eos_id(self):
190
+ return self.sp_model.eos_id()
191
+
192
+ @property
193
+ def bos_id(self):
194
+ return self.sp_model.bos_id()
195
+
196
+ @property
197
+ def unk_id(self):
198
+ return self.sp_model.unk_id()
199
+
200
+ @property
201
+ def pad_token_id(self):
202
+ return self.pad_id
203
+
204
+ @property
205
+ def eos_token_id(self):
206
+ return self.eos_id
207
+
208
+
209
+ @dataclass
210
+ class ExtraTokens:
211
+ msg_end: int
212
+ user_msg_start: int
213
+ assistant_msg_start: int
214
+ name_end: int
215
+ media_begin: int
216
+ media_content: int
217
+ media_end: int
218
+ pad: int
219
+
220
+
221
+ def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
222
+ if isinstance(tokenizer, SPieceTokenizer):
223
+ map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)
224
+ else:
225
+ raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
226
+
227
+ return ExtraTokens(
228
+ msg_end=map_fn('[extra_id_0]'),
229
+ user_msg_start=map_fn('[extra_id_1]'),
230
+ assistant_msg_start=map_fn('[extra_id_2]'),
231
+ name_end=map_fn('[extra_id_12]'),
232
+ media_begin=map_fn('[extra_id_13]'),
233
+ media_content=map_fn('[extra_id_14]'),
234
+ media_end=map_fn('[extra_id_15]'),
235
+ pad=tokenizer.pad_id
236
+ )
237
+
238
+ def get_tokenizer_and_extra_tokens():
239
+ sp_model_path = "resources/tokenizer/160k.model"
240
+ tokenizer = SPieceTokenizer(sp_model_path)
241
+ tokenizer.set_add_dummy_prefix(False)
242
+ extra_tokens = instantiate_extra_tokens(tokenizer)
243
+ return tokenizer, extra_tokens
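A small usage sketch, assuming the SentencePiece model has already been fetched to `resources/tokenizer/160k.model` (see `download_pretrain.py`); the expected ids below come from `test/test_tokenizer.py` later in this commit.

```python
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens

tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()

ids = tokenizer.encode("audio")             # [26229] per test/test_tokenizer.py
text = tokenizer.decode(ids)                # "audio"
print(ids, text, extra_tokens.media_begin)  # media_begin == 273
```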
readme.md ADDED
@@ -0,0 +1,39 @@
1
+ # MoonCast: High-Quality Zero-Shot Podcast Generation
2
+
3
+ ## Overview
4
+ Demo page: [demo](https://mooncastdemo.github.io)
5
+
6
+ Paper: [paper](https://arxiv.org/abs/2503.14345)
7
+
8
+ We open-source this system to advance the field of human-like speech synthesis. Our goal is to create more natural and expressive synthetic voices that bridge the gap between machines and humans. We hope this project will inspire researchers and developers to explore new possibilities in voice technology. We warmly welcome contributions from anyone interested in this project. Whether through code, documentation, feedback, or sharing your insights, every input helps make this project better.
9
+
10
+ ### Abstract
11
+ Recent advances in text-to-speech synthesis have achieved notable success in generating high-quality short utterances for individual speakers. However, these systems still face challenges when extending their capabilities to long, multi-speaker, and spontaneous dialogues, typical of real-world scenarios such as podcasts. These limitations arise from two primary challenges: 1) long speech: podcasts typically span several minutes, exceeding the upper limit of most existing work; 2) spontaneity: podcasts are marked by their spontaneous, oral nature, which sharply contrasts with formal, written contexts; existing works often fall short in capturing this spontaneity. In this paper, we propose MoonCast, a solution for high-quality zero-shot podcast generation, aiming to synthesize natural podcast-style speech from text-only sources (e.g., stories, technical reports, news in TXT, PDF, or Web URL formats) using the voices of unseen speakers. To generate long audio, we adopt a long-context language model-based audio modeling approach utilizing large-scale long-context speech data. To enhance spontaneity, we utilize a podcast generation module to generate scripts with spontaneous details, which have been empirically shown to be as crucial as the text-to-speech modeling itself. Experiments demonstrate that MoonCast outperforms baselines, with particularly notable improvements in spontaneity and coherence.
12
+
13
+ ## Environment Setup
14
+ - Create conda environment.
15
+
16
+ ``` sh
17
+ conda create -n mooncast -y python=3.10
18
+ conda activate mooncast
19
+ pip install -r requirements.txt
20
+ pip install flash-attn --no-build-isolation
21
+ pip install huggingface_hub
22
+ pip install gradio==5.22.0
23
+ ```
24
+
25
+ - Download the pretrained weights.
26
+ ``` sh
27
+ python download_pretrain.py
28
+ ```
29
+
30
+ ## Example Usage
31
+
32
+ The audio prompts used in this project are sourced from publicly available podcast segments and are intended solely for demonstration purposes. Redistribution of these audio files, whether in their original form or as generated audio, is strictly prohibited. If you have any concerns or questions regarding the use of these audio files, please contact us at [email protected]
33
+
34
+ ```sh
35
+ CUDA_VISIBLE_DEVICES=0 python inference.py
36
+ ```
37
+
38
+ ## Disclaimer
39
+ This project is intended for **research purposes only**. We strongly encourage users to **use this project and its generated audio responsibly**. **We are not responsible for any misuse or abuse of this project**. By using this project, you agree to comply with all applicable laws and ethical guidelines.
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ torch==2.3.1
2
+ torchaudio==2.3.1
3
+ sentencepiece==0.2.0
4
+ protobuf
5
+ numpy
6
+
7
+ librosa==0.9.1
8
+ pyyaml
9
+ transformers
10
+ safetensors
11
+ einops
12
+ scipy
13
+ timm==1.0.7
14
+ torchdyn
15
+ librosa
16
+ accelerate==0.26.0
17
+ ninja
18
+ cryptography
test/test_audio_detokenizer.py ADDED
@@ -0,0 +1,34 @@
1
+ import sys
2
+ import torch
3
+ sys.path.append('.')
4
+ from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
5
+ from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize
6
+ import torchaudio
7
+ import librosa
8
+
9
+ if __name__ == '__main__':
10
+ audio_tokenizer = get_audio_tokenizer()
11
+ audio_detokenizer = get_audio_detokenizer()
12
+
13
+ input_wav_16k, _ = librosa.load("en_prompt0.wav", sr=16000)
14
+ input_wav_24k, _ = librosa.load("en_prompt0.wav", sr=24000)
15
+
16
+ prompt_sec = 1
17
+ prompt_wav_16k = input_wav_16k[:16000*prompt_sec]
18
+ prompt_wav_24k = input_wav_24k[:24000*prompt_sec]
19
+ input_wav_16k = input_wav_16k[16000*prompt_sec:]
20
+ input_wav_24k = input_wav_24k[24000*prompt_sec:]
21
+
22
+ prompt_wav_24k = torch.tensor(prompt_wav_24k)[None, :].cuda()
23
+ prompt_wav_16k = torch.tensor(prompt_wav_16k)[None, :].cuda()
24
+ input_wav_24k = torch.tensor(input_wav_24k)[None, :].cuda()
25
+ input_wav_16k = torch.tensor(input_wav_16k)[None, :].cuda()
26
+
27
+ semantic_token = audio_tokenizer.tokenize(input_wav_16k)
28
+ prompt_semantic_token = audio_tokenizer.tokenize(prompt_wav_16k)
29
+
30
+ recon_wav = detokenize(audio_detokenizer, semantic_token, prompt_wav_24k, prompt_semantic_token)
31
+ print(recon_wav.shape)
32
+ torchaudio.save("test/tmp_recon_en_prompt0.wav", recon_wav.cpu(), 24000)
33
+
34
+ print("All tests passed!")
test/test_audio_tokenizer.py ADDED
@@ -0,0 +1,15 @@
1
+ import sys
2
+ sys.path.append('.')
3
+ from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
4
+ import torch
5
+
6
+ if __name__ == '__main__':
7
+ audio_tokenizer = get_audio_tokenizer()
8
+
9
+ input_wav = torch.zeros(1, 8000)
10
+ semantic_token = audio_tokenizer.tokenize(input_wav)
11
+ semantic_token = semantic_token.cpu().numpy().tolist()
12
+ assert semantic_token == [[ 765, 3512, 7469, 7469, 7028, 2567, 6008, 7469, 6217, 2567, 7649, 7469,
13
+ 3292, 2567, 7649, 7469, 3292, 2567, 948, 7469, 3292, 2567, 948, 7469]]
14
+
15
+ print("All tests passed!")
test/test_tokenizer.py ADDED
@@ -0,0 +1,37 @@
1
+ import sys
2
+ sys.path.append('.')
3
+ from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
4
+
5
+
6
+
7
+ if __name__ == '__main__':
8
+ tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()
9
+
10
+ assert tokenizer.encode("user") == [1495]
11
+ assert tokenizer.decode([1495]) == "user"
12
+
13
+ assert tokenizer.encode("0") == [501]
14
+ assert tokenizer.decode([501]) == "0"
15
+
16
+ assert tokenizer.encode("1") == [503]
17
+ assert tokenizer.decode([503]) == "1"
18
+
19
+ assert tokenizer.encode("assistant") == [110866]
20
+ assert tokenizer.decode([110866]) == "assistant"
21
+
22
+ assert tokenizer.encode("audio") == [26229]
23
+ assert tokenizer.decode([26229]) == "audio"
24
+
25
+
26
+ assert extra_tokens.msg_end == 260
27
+ assert extra_tokens.user_msg_start == 261
28
+ assert extra_tokens.assistant_msg_start == 262
29
+ assert extra_tokens.name_end == 272
30
+ assert extra_tokens.media_begin == 273
31
+ assert extra_tokens.media_content == 274
32
+ assert extra_tokens.media_end == 275
33
+
34
+ assert [tokenizer.convert_tokens_to_ids(i) for i in ['<0x0A>', '</s>', '[extra_id_0]']] == [14, 1, 260]
35
+
36
+ print("All tests passed!")
37
+
zh_prompt0.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4a334352093e0aafb2217b53f9a027e7619c74b3823f70e752aa9f51aebc597
3
+ size 240044
zh_prompt1.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21a376ffa8591b16ad3c2403d6d9f9a5053b799039770c5d49ceaa0c92d6eafe
3
+ size 228964