Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- .gitignore +7 -0
- .gradio/certificate.pem +31 -0
- LICENSE +21 -0
- README.md +3 -9
- app.py +211 -0
- download_pretrain.py +2 -0
- en_prompt0.wav +3 -0
- en_prompt1.wav +3 -0
- inference.py +342 -0
- modules/audio_detokenizer/audio_detokenizer.py +249 -0
- modules/audio_detokenizer/bigvgan_wrapper.py +94 -0
- modules/audio_detokenizer/flow_matching/dit_block.py +236 -0
- modules/audio_detokenizer/flow_matching/model.py +295 -0
- modules/audio_detokenizer/flow_matching/ode_wrapper.py +164 -0
- modules/audio_detokenizer/flow_matching/scheduler.py +82 -0
- modules/audio_detokenizer/semantic_fm_prefix_streaming.py +273 -0
- modules/audio_detokenizer/vocoder/activations.py +123 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py +0 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py +0 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py +77 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h +29 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py +86 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h +92 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py +6 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py +30 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py +101 -0
- modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py +58 -0
- modules/audio_detokenizer/vocoder/bigvgan.py +492 -0
- modules/audio_detokenizer/vocoder/utils.py +105 -0
- modules/audio_tokenizer/audio_tokenizer.py +76 -0
- modules/audio_tokenizer/quantize/__init__.py +3 -0
- modules/audio_tokenizer/quantize/factorized_vector_quantize.py +145 -0
- modules/audio_tokenizer/quantize/residual_vq.py +168 -0
- modules/audio_tokenizer/quantize/vector_quantize.py +396 -0
- modules/audio_tokenizer/rep_codec.py +197 -0
- modules/audio_tokenizer/transformer.py +234 -0
- modules/audio_tokenizer/vocos.py +845 -0
- modules/tokenizer/tokenizer.py +243 -0
- readme.md +39 -0
- requirements.txt +18 -0
- test/test_audio_detokenizer.py +34 -0
- test/test_audio_tokenizer.py +15 -0
- test/test_tokenizer.py +37 -0
- zh_prompt0.wav +3 -0
- zh_prompt1.wav +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
en_prompt0.wav filter=lfs diff=lfs merge=lfs -text
|
37 |
+
en_prompt1.wav filter=lfs diff=lfs merge=lfs -text
|
38 |
+
zh_prompt0.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
zh_prompt1.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.safetensors
|
2 |
+
*.pt
|
3 |
+
*.vscode
|
4 |
+
**/__pycache__/
|
5 |
+
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/build/
|
6 |
+
tmp*
|
7 |
+
resources/
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Zeqian Ju
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 🚀
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: purple
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.23.0
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: mooncast
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.22.0
|
6 |
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from huggingface_hub import snapshot_download
|
5 |
+
from inference import Model
|
6 |
+
import base64
|
7 |
+
|
8 |
+
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
|
9 |
+
model = Model()
|
10 |
+
model.generate_config.max_new_tokens = 50 * 50 # no more than 20s per turn
|
11 |
+
|
12 |
+
|
13 |
+
def process_json_and_generate_audio(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1, json_dialogue_input_str):
|
14 |
+
try:
|
15 |
+
print(json_dialogue_input_str, type(json_dialogue_input_str))
|
16 |
+
print(prompt_audio_role0_file, prompt_text_role0, prompt_audio_role1_file, prompt_text_role1)
|
17 |
+
# json_data = json.loads(json_dialogue_input_str)
|
18 |
+
json_data = eval(json_dialogue_input_str.strip())
|
19 |
+
print(json_data, type(json_data))
|
20 |
+
|
21 |
+
def validate_json(data):
|
22 |
+
try:
|
23 |
+
if not isinstance(data, list):
|
24 |
+
return "json must be a dictionary"
|
25 |
+
cur_spk_should_be = 0
|
26 |
+
for item in data:
|
27 |
+
if item['role'] != str(cur_spk_should_be):
|
28 |
+
return f"role should be {cur_spk_should_be} in item {item}"
|
29 |
+
cur_spk_should_be = 1 - cur_spk_should_be
|
30 |
+
return None
|
31 |
+
except Exception as e:
|
32 |
+
return str(e)
|
33 |
+
|
34 |
+
|
35 |
+
validation_error = validate_json(json_data)
|
36 |
+
if validation_error:
|
37 |
+
raise gr.Error(validation_error)
|
38 |
+
|
39 |
+
role_mapping = {
|
40 |
+
"0": {
|
41 |
+
"ref_audio": prompt_audio_role0_file,
|
42 |
+
"ref_text": prompt_text_role0,
|
43 |
+
},
|
44 |
+
"1": {
|
45 |
+
"ref_audio": prompt_audio_role1_file,
|
46 |
+
"ref_text": prompt_text_role1,
|
47 |
+
}
|
48 |
+
}
|
49 |
+
|
50 |
+
# 完整输入 JSON (你需要根据你的模型调整)
|
51 |
+
model_input_json = {
|
52 |
+
"role_mapping": role_mapping,
|
53 |
+
"dialogue": json_data, # 从用户输入的 JSON 中获取 dialogue
|
54 |
+
}
|
55 |
+
print("模型推理输入 JSON:", model_input_json)
|
56 |
+
|
57 |
+
|
58 |
+
# 4. **[重要] 调用你的 Model 类的 `inference` 方法**
|
59 |
+
# audio_bytes = model.inference(model_input_json)
|
60 |
+
|
61 |
+
# 5. 返回音频 bytes 给 Gradio (Gradio 会自动处理音频 bytes 并播放)
|
62 |
+
# return base64.b64decode(audio_bytes)
|
63 |
+
for cur_chunk in model.inference(model_input_json, streaming=True):
|
64 |
+
yield base64.b64decode(cur_chunk)
|
65 |
+
|
66 |
+
except Exception as e:
|
67 |
+
# return str(e) # 返回错误信息给 Gradio
|
68 |
+
raise gr.Error(str(e))
|
69 |
+
|
70 |
+
title_en = "# PODCAST generator (supports English and Chinese)"
|
71 |
+
title_zh = "# 播客生成 (支持英文和中文)"
|
72 |
+
|
73 |
+
input_labels_en = ["Prompt Audio for Role 0", "Prompt Text for Role 0", "Prompt Audio for Role 1", "Prompt Text for Role 1", "Dialogue JSON Input"]
|
74 |
+
input_labels_zh = ["角色 0 的 Prompt 音频", "角色 0 的 Prompt 文本", "角色 1 的 Prompt 音频", "角色 1 的 Prompt 文本", "对话 JSON 输入"]
|
75 |
+
|
76 |
+
output_label_en = "Generated Audio Output (streaming)"
|
77 |
+
output_label_zh = "生成的音频输出(流式)"
|
78 |
+
|
79 |
+
example_prompt_text_role0_en = "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem."
|
80 |
+
example_prompt_text_role0_zh = "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。"
|
81 |
+
example_prompt_text_role1_en = "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."
|
82 |
+
example_prompt_text_role1_zh = "他最后就能让同样食材炒出来的菜味道大大提升。"
|
83 |
+
|
84 |
+
text_placeholder_zh = "对话轮流进行, 每轮最多50秒。文本越自然, 生成的音频效果越好。"
|
85 |
+
text_placeholder_en = "Dialogue alternates between roles. Limit each turn to a maximum of 50 seconds. The more natural the text, the better the generated audio."
|
86 |
+
|
87 |
+
|
88 |
+
example_json_en = '''[
|
89 |
+
{
|
90 |
+
"role": "0",
|
91 |
+
"text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"role": "1",
|
95 |
+
"text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.",
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"role": "0",
|
99 |
+
"text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.",
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"role": "1",
|
103 |
+
"text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
|
104 |
+
}
|
105 |
+
]'''
|
106 |
+
example_json_zh = '''[
|
107 |
+
{
|
108 |
+
"role": "0",
|
109 |
+
"text": "我觉得啊,就是经历了这么多���的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"role": "1",
|
113 |
+
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"role": "0",
|
117 |
+
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
|
118 |
+
}
|
119 |
+
]
|
120 |
+
'''
|
121 |
+
|
122 |
+
# examples_en = [
|
123 |
+
# ['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en]
|
124 |
+
# ]
|
125 |
+
# examples_zh = [
|
126 |
+
# ['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
|
127 |
+
# ]
|
128 |
+
|
129 |
+
examples = [
|
130 |
+
['./en_prompt0.wav', example_prompt_text_role0_en, './en_prompt1.wav', example_prompt_text_role1_en, example_json_en],
|
131 |
+
['./zh_prompt0.wav', example_prompt_text_role0_zh, './zh_prompt1.wav', example_prompt_text_role1_zh, example_json_zh]
|
132 |
+
]
|
133 |
+
|
134 |
+
# -------------------- 更新界面元素的函数 --------------------
|
135 |
+
def update_ui_language(language):
|
136 |
+
if language == "English":
|
137 |
+
return gr.update(value=title_en), \
|
138 |
+
gr.update(label="UI Language"), \
|
139 |
+
gr.update(label=input_labels_en[0]), \
|
140 |
+
gr.update(label=input_labels_en[1]), \
|
141 |
+
gr.update(label=input_labels_en[2]), \
|
142 |
+
gr.update(label=input_labels_en[3]), \
|
143 |
+
gr.update(label=input_labels_en[4], placeholder=text_placeholder_en), \
|
144 |
+
gr.update(label=output_label_en), \
|
145 |
+
gr.update(value="Submit"), \
|
146 |
+
gr.update(label="Examples (Demonstration Use Only. Do Not Redistribute.)", headers=input_labels_en)
|
147 |
+
|
148 |
+
elif language == "中文":
|
149 |
+
return gr.update(value=title_zh), \
|
150 |
+
gr.update(label="UI 语言"), \
|
151 |
+
gr.update(label=input_labels_zh[0]), \
|
152 |
+
gr.update(label=input_labels_zh[1]), \
|
153 |
+
gr.update(label=input_labels_zh[2]), \
|
154 |
+
gr.update(label=input_labels_zh[3]), \
|
155 |
+
gr.update(label=input_labels_zh[4], placeholder=text_placeholder_zh), \
|
156 |
+
gr.update(label=output_label_zh), \
|
157 |
+
gr.update(value="提交"), \
|
158 |
+
gr.update(label="示例 (仅用于展示,切勿私自传播。)", headers=input_labels_zh)
|
159 |
+
|
160 |
+
else:
|
161 |
+
raise ValueError("Invalid language selected")
|
162 |
+
|
163 |
+
|
164 |
+
audio_output = gr.Audio(label=output_label_en, streaming=True)
|
165 |
+
css = """
|
166 |
+
.centered-title { /* CSS rule for centering title */
|
167 |
+
text-align: center !important;
|
168 |
+
}
|
169 |
+
"""
|
170 |
+
# -------------------- Gradio 界面定义 (修改) --------------------
|
171 |
+
with gr.Blocks(css=css) as iface:
|
172 |
+
|
173 |
+
title_output = gr.Markdown(value=title_zh, elem_classes="centered-title")
|
174 |
+
language_choice = gr.Radio(["中文", "English"], value="中文", label="UI语言")
|
175 |
+
|
176 |
+
with gr.Row(): # Main row to create two columns
|
177 |
+
with gr.Column(scale=2):
|
178 |
+
json_input = gr.TextArea(label=input_labels_zh[4], lines=15, placeholder=text_placeholder_zh) # Dialogue JSON Input
|
179 |
+
|
180 |
+
with gr.Column(scale=1): # Right column (narrower - scale=1) for prompt inputs
|
181 |
+
audio_input_role0 = gr.Audio(type="filepath", label=input_labels_zh[0]) # Prompt Audio for Role 0
|
182 |
+
text_input_role0 = gr.TextArea(label=input_labels_zh[1], lines=2) # Prompt Text for Role 0
|
183 |
+
|
184 |
+
with gr.Column(scale=1): #
|
185 |
+
audio_input_role1 = gr.Audio(type="filepath", label=input_labels_zh[2]) # Prompt Audio for Role 1
|
186 |
+
text_input_role1 = gr.TextArea(label=input_labels_zh[3], lines=2) # Prompt Text for Role 1
|
187 |
+
|
188 |
+
examples_component = gr.Examples(
|
189 |
+
examples=examples,
|
190 |
+
inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
|
191 |
+
cache_examples=False,
|
192 |
+
label="示例(仅用于展示,切勿私自传播。)",
|
193 |
+
)
|
194 |
+
|
195 |
+
submit_button = gr.Button("提交")
|
196 |
+
|
197 |
+
submit_button.click(
|
198 |
+
fn=process_json_and_generate_audio,
|
199 |
+
inputs=[audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input],
|
200 |
+
outputs=audio_output
|
201 |
+
)
|
202 |
+
audio_output.render()
|
203 |
+
|
204 |
+
language_choice.change(
|
205 |
+
fn=update_ui_language,
|
206 |
+
inputs=language_choice,
|
207 |
+
outputs=[title_output, language_choice, audio_input_role0, text_input_role0, audio_input_role1, text_input_role1, json_input, audio_output, submit_button, examples_component.dataset]
|
208 |
+
)
|
209 |
+
|
210 |
+
|
211 |
+
iface.launch(share=True)
|
download_pretrain.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import snapshot_download
|
2 |
+
snapshot_download(repo_id="jzq11111/mooncast", local_dir='./resources/')
|
en_prompt0.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d789de64f286dd5d0198e67310bb334d9c0bfe95b3369747ccab58ae614c3292
|
3 |
+
size 684388
|
en_prompt1.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:407697746e6a6da1d540fe06379c54c36dd3b1a06fb3a789eb457b981d2cc7f4
|
3 |
+
size 615116
|
inference.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import sys
|
3 |
+
sys.path.append(".")
|
4 |
+
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
|
5 |
+
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
|
6 |
+
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
from glob import glob
|
10 |
+
import base64
|
11 |
+
import io
|
12 |
+
import torchaudio
|
13 |
+
from transformers import AutoModelForCausalLM, GenerationConfig
|
14 |
+
import librosa
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
class Model(object):
|
18 |
+
def __init__(self):
|
19 |
+
|
20 |
+
|
21 |
+
self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
|
22 |
+
self.speech_token_offset = 163840
|
23 |
+
print(self.extra_tokens)
|
24 |
+
self.assistant_ids = self.tokenizer.encode("assistant") # [110866]
|
25 |
+
self.user_ids = self.tokenizer.encode("user") # [1495]
|
26 |
+
self.audio_ids = self.tokenizer.encode("audio") # [26229]
|
27 |
+
self.spk_0_ids = self.tokenizer.encode("0") # [501]
|
28 |
+
self.spk_1_ids = self.tokenizer.encode("1") # [503]
|
29 |
+
|
30 |
+
self.msg_end = self.extra_tokens.msg_end # 260
|
31 |
+
self.user_msg_start = self.extra_tokens.user_msg_start # 261
|
32 |
+
self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262
|
33 |
+
self.name_end = self.extra_tokens.name_end # 272
|
34 |
+
self.media_begin = self.extra_tokens.media_begin # 273
|
35 |
+
self.media_content = self.extra_tokens.media_content # 274
|
36 |
+
self.media_end = self.extra_tokens.media_end # 275
|
37 |
+
|
38 |
+
self.audio_tokenizer = get_audio_tokenizer()
|
39 |
+
self.audio_detokenizer = get_audio_detokenizer()
|
40 |
+
model_path = "resources/text2semantic"
|
41 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda:0", torch_dtype=torch.bfloat16, trust_remote_code=True, force_download=True).to(torch.cuda.current_device())
|
42 |
+
self.generate_config = GenerationConfig(
|
43 |
+
max_new_tokens=200 * 50, # no more than 200s per turn
|
44 |
+
do_sample=True,
|
45 |
+
top_k=30,
|
46 |
+
top_p=0.8,
|
47 |
+
temperature=0.8,
|
48 |
+
eos_token_id=self.media_end,
|
49 |
+
)
|
50 |
+
|
51 |
+
def _clean_text(self, text):
|
52 |
+
# you can add front-end processing here
|
53 |
+
text = text.replace("“", "")
|
54 |
+
text = text.replace("”", "")
|
55 |
+
text = text.replace("...", " ")
|
56 |
+
text = text.replace("…", " ")
|
57 |
+
text = text.replace("*", "")
|
58 |
+
text = text.replace(":", ",")
|
59 |
+
text = text.replace("‘", "'")
|
60 |
+
text = text.replace("’", "'")
|
61 |
+
text = text.strip()
|
62 |
+
return text
|
63 |
+
|
64 |
+
@torch.inference_mode()
|
65 |
+
def _process_text(self, js):
|
66 |
+
|
67 |
+
if "role_mapping" in js:
|
68 |
+
for role in js["role_mapping"].keys():
|
69 |
+
js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
|
70 |
+
|
71 |
+
for turn in js["dialogue"]:
|
72 |
+
turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
|
73 |
+
return js
|
74 |
+
|
75 |
+
def inference(self, js, streaming=False):
|
76 |
+
js = self._process_text(js)
|
77 |
+
if "role_mapping" not in js:
|
78 |
+
return self.infer_without_prompt(js, streaming)
|
79 |
+
else:
|
80 |
+
return self.infer_with_prompt(js, streaming)
|
81 |
+
|
82 |
+
@torch.inference_mode()
|
83 |
+
def infer_with_prompt(self, js, streaming=False):
|
84 |
+
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
|
85 |
+
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
|
86 |
+
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
|
87 |
+
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
|
88 |
+
|
89 |
+
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
|
90 |
+
media_end = [self.media_end] + [self.msg_end]
|
91 |
+
|
92 |
+
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
|
93 |
+
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
|
94 |
+
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
|
95 |
+
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
|
96 |
+
|
97 |
+
|
98 |
+
prompt = []
|
99 |
+
cur_role_dict = dict()
|
100 |
+
for role, role_item in js["role_mapping"].items():
|
101 |
+
waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
|
102 |
+
waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
|
103 |
+
|
104 |
+
waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
|
105 |
+
waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
|
106 |
+
|
107 |
+
semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
|
108 |
+
semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
|
109 |
+
prompt_ids = semantic_tokens + self.speech_token_offset
|
110 |
+
|
111 |
+
cur_role_dict[role] = {
|
112 |
+
"ref_bpe_ids": role_item["ref_bpe_ids"],
|
113 |
+
"wav_24k": waveform_24k,
|
114 |
+
"semantic_tokens": semantic_tokens,
|
115 |
+
"prompt_ids": prompt_ids
|
116 |
+
}
|
117 |
+
|
118 |
+
prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
|
119 |
+
prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
|
120 |
+
|
121 |
+
for seg_id, turn in enumerate(js["dialogue"]):
|
122 |
+
role_id = turn["role"]
|
123 |
+
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
|
124 |
+
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
|
125 |
+
prompt = prompt + cur_start_ids
|
126 |
+
|
127 |
+
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
|
128 |
+
|
129 |
+
prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
|
130 |
+
prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
|
131 |
+
|
132 |
+
|
133 |
+
generation_config = self.generate_config
|
134 |
+
# you can modify sampling strategy here
|
135 |
+
|
136 |
+
wav_list = []
|
137 |
+
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
|
138 |
+
role_id = turn["role"]
|
139 |
+
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
|
140 |
+
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
|
141 |
+
len_prompt = prompt.shape[1]
|
142 |
+
generation_config.min_length = len_prompt + 2
|
143 |
+
# print(generation_config)
|
144 |
+
# todo: add streaming support for generate function
|
145 |
+
outputs = self.model.generate(prompt,
|
146 |
+
generation_config=generation_config)
|
147 |
+
if outputs[0, -1] == self.media_end:
|
148 |
+
outputs = outputs[:, :-1]
|
149 |
+
output_token = outputs[:, len_prompt:]
|
150 |
+
prompt = torch.cat([outputs, media_end], dim=-1)
|
151 |
+
|
152 |
+
torch_token = output_token - self.speech_token_offset
|
153 |
+
if streaming:
|
154 |
+
# gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
|
155 |
+
# yield detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
|
156 |
+
for cur_chunk in detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"], streaming=True):
|
157 |
+
cur_chunk = cur_chunk.cpu()
|
158 |
+
cur_chunk = cur_chunk / cur_chunk.abs().max()
|
159 |
+
cur_buffer = io.BytesIO()
|
160 |
+
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
|
161 |
+
audio_bytes = cur_buffer.getvalue()
|
162 |
+
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
163 |
+
yield audio_b64
|
164 |
+
else:
|
165 |
+
gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
|
166 |
+
gen_speech_fm = gen_speech_fm.cpu()
|
167 |
+
gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
|
168 |
+
wav_list.append(gen_speech_fm)
|
169 |
+
del torch_token
|
170 |
+
if not streaming:
|
171 |
+
concat_wav = torch.cat(wav_list, dim=-1).cpu()
|
172 |
+
# print(concat_wav.shape)
|
173 |
+
buffer = io.BytesIO()
|
174 |
+
torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
|
175 |
+
audio_bytes = buffer.getvalue()
|
176 |
+
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
177 |
+
return audio_b64
|
178 |
+
|
179 |
+
@torch.inference_mode()
|
180 |
+
def infer_without_prompt(self, js, streaming=False):
|
181 |
+
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
|
182 |
+
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
|
183 |
+
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
|
184 |
+
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
|
185 |
+
|
186 |
+
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
|
187 |
+
media_end = [self.media_end] + [self.msg_end]
|
188 |
+
|
189 |
+
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
|
190 |
+
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
|
191 |
+
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
|
192 |
+
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
|
193 |
+
|
194 |
+
|
195 |
+
prompt = []
|
196 |
+
for seg_id, turn in enumerate(js["dialogue"]):
|
197 |
+
role_id = turn["role"]
|
198 |
+
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
|
199 |
+
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
|
200 |
+
prompt = prompt + cur_start_ids
|
201 |
+
|
202 |
+
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
|
203 |
+
generation_config = self.generate_config
|
204 |
+
# you can modify sampling strategy here
|
205 |
+
|
206 |
+
wav_list = []
|
207 |
+
for seg_id, turn in tqdm(enumerate(js["dialogue"])):
|
208 |
+
role_id = turn["role"]
|
209 |
+
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
|
210 |
+
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
|
211 |
+
len_prompt = prompt.shape[1]
|
212 |
+
generation_config.min_length = len_prompt + 2
|
213 |
+
# print(generation_config)
|
214 |
+
# todo: add streaming support for generate function
|
215 |
+
outputs = self.model.generate(prompt,
|
216 |
+
generation_config=generation_config)
|
217 |
+
if outputs[0, -1] == self.media_end:
|
218 |
+
outputs = outputs[:, :-1]
|
219 |
+
output_token = outputs[:, len_prompt:]
|
220 |
+
prompt = torch.cat([outputs, media_end], dim=-1)
|
221 |
+
|
222 |
+
torch_token = output_token - self.speech_token_offset
|
223 |
+
if streaming:
|
224 |
+
for cur_chunk in detokenize_noref(self.audio_detokenizer, torch_token, streaming=True):
|
225 |
+
cur_chunk = cur_chunk.cpu()
|
226 |
+
cur_chunk = cur_chunk / cur_chunk.abs().max()
|
227 |
+
cur_buffer = io.BytesIO()
|
228 |
+
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
|
229 |
+
audio_bytes = cur_buffer.getvalue()
|
230 |
+
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
231 |
+
yield audio_b64
|
232 |
+
else:
|
233 |
+
gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
|
234 |
+
gen_speech_fm = gen_speech_fm.cpu()
|
235 |
+
gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
|
236 |
+
wav_list.append(gen_speech_fm)
|
237 |
+
del torch_token
|
238 |
+
|
239 |
+
if not streaming:
|
240 |
+
concat_wav = torch.cat(wav_list, dim=-1).cpu()
|
241 |
+
# print(concat_wav.shape)
|
242 |
+
buffer = io.BytesIO()
|
243 |
+
torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
|
244 |
+
audio_bytes = buffer.getvalue()
|
245 |
+
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
246 |
+
return audio_b64
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == "__main__":
|
250 |
+
model = Model()
|
251 |
+
|
252 |
+
# speaker should be interleaved
|
253 |
+
zh_test_json = {
|
254 |
+
"role_mapping": {
|
255 |
+
"0": {
|
256 |
+
"ref_audio": "./zh_prompt0.wav",
|
257 |
+
"ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。", #asr output
|
258 |
+
},
|
259 |
+
"1": {
|
260 |
+
"ref_audio": "./zh_prompt1.wav",
|
261 |
+
"ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。" #asr output
|
262 |
+
}
|
263 |
+
},
|
264 |
+
"dialogue": [
|
265 |
+
{
|
266 |
+
"role": "0",
|
267 |
+
"text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"role": "1",
|
271 |
+
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"role": "0",
|
275 |
+
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
|
276 |
+
}
|
277 |
+
]
|
278 |
+
}
|
279 |
+
audio_bytes = model.inference(zh_test_json)
|
280 |
+
file_to_save = open(f"tmp_generated_zh.mp3", "wb")
|
281 |
+
file_to_save.write(base64.b64decode(audio_bytes))
|
282 |
+
print("zh done")
|
283 |
+
|
284 |
+
# speaker should be interleaved
|
285 |
+
en_test_json = {
|
286 |
+
"role_mapping": {
|
287 |
+
"0": {
|
288 |
+
"ref_audio": "./en_prompt0.wav",
|
289 |
+
"ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.", #asr output
|
290 |
+
},
|
291 |
+
"1": {
|
292 |
+
"ref_audio": "./en_prompt1.wav",
|
293 |
+
"ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." #asr output
|
294 |
+
}
|
295 |
+
},
|
296 |
+
"dialogue": [
|
297 |
+
{
|
298 |
+
"role": "0",
|
299 |
+
"text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"role": "1",
|
303 |
+
"text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"role": "0",
|
307 |
+
"text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"role": "1",
|
311 |
+
"text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
|
312 |
+
}
|
313 |
+
]
|
314 |
+
}
|
315 |
+
audio_bytes = model.inference(en_test_json)
|
316 |
+
file_to_save = open(f"tmp_generated_en.mp3", "wb")
|
317 |
+
file_to_save.write(base64.b64decode(audio_bytes))
|
318 |
+
print("en done")
|
319 |
+
|
320 |
+
|
321 |
+
# also support inference without prompt
|
322 |
+
# speaker should be interleaved
|
323 |
+
without_prompt_test_json = {
|
324 |
+
"dialogue": [
|
325 |
+
{
|
326 |
+
"role": "0",
|
327 |
+
"text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"role": "1",
|
331 |
+
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"role": "0",
|
335 |
+
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
|
336 |
+
}
|
337 |
+
]
|
338 |
+
}
|
339 |
+
audio_bytes = model.inference(without_prompt_test_json)
|
340 |
+
file_to_save = open(f"tmp_generated_woprompt.mp3", "wb")
|
341 |
+
file_to_save.write(base64.b64decode(audio_bytes))
|
342 |
+
print("without prompt done")
|
modules/audio_detokenizer/audio_detokenizer.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper
|
5 |
+
from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class PrefixStreamingFlowMatchingDetokenizer:
|
9 |
+
def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None:
|
10 |
+
self.dtype = torch.bfloat16
|
11 |
+
|
12 |
+
print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer")
|
13 |
+
|
14 |
+
self.vocoder = vocoder
|
15 |
+
self.vocoder.to_dtype(self.dtype)
|
16 |
+
|
17 |
+
self.semantic_fm = fm
|
18 |
+
|
19 |
+
# initialize mel_spec
|
20 |
+
self.max_pos_size = 4096
|
21 |
+
self.is_timbre_semantic_token = False
|
22 |
+
self.pre_mel = None
|
23 |
+
self.frame_size = 480 # how many samples in a frame
|
24 |
+
self.pre_wav = None
|
25 |
+
self.state_dict_backup = None
|
26 |
+
self.hamming_window_cache = {}
|
27 |
+
self.previous_chunk_left = None
|
28 |
+
self.look_ahead_tokens = look_ahead_tokens
|
29 |
+
|
30 |
+
self.clear_states()
|
31 |
+
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device,
|
35 |
+
look_ahead_tokens=0,
|
36 |
+
max_prompt_chunk=2, max_kv_cache_tokens=900,
|
37 |
+
use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
|
38 |
+
bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
|
39 |
+
semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
|
40 |
+
use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule)
|
41 |
+
return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)
|
42 |
+
|
43 |
+
@torch.inference_mode()
|
44 |
+
def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):
|
45 |
+
"""
|
46 |
+
Arguments:
|
47 |
+
timbre_speech: torch.Tensor, shape [B, N_speech_24k]
|
48 |
+
timbre_semantic_token: torch.Tensor, shape [B, N]
|
49 |
+
chunk_size: int, chunk size for prefilling
|
50 |
+
timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech
|
51 |
+
"""
|
52 |
+
if timbre_mel is None:
|
53 |
+
assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is not None"
|
54 |
+
assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0
|
55 |
+
assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
|
56 |
+
|
57 |
+
mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))
|
58 |
+
else:
|
59 |
+
assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0
|
60 |
+
assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
|
61 |
+
mel_spec = timbre_mel.squeeze(0)
|
62 |
+
|
63 |
+
if mel_spec.shape[0] < timbre_semantic_token.shape[1]:
|
64 |
+
# pad mel_spec
|
65 |
+
mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0]))
|
66 |
+
elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:
|
67 |
+
# truncate mel_spec
|
68 |
+
mel_spec = mel_spec[:timbre_semantic_token.shape[1], :]
|
69 |
+
|
70 |
+
# clear all states
|
71 |
+
self.semantic_fm.clear_all_states()
|
72 |
+
self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)
|
73 |
+
self.state_dict_backup = self.semantic_fm.state_dict()
|
74 |
+
|
75 |
+
@torch.inference_mode()
|
76 |
+
def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1):
|
77 |
+
assert len(semantic_token.shape) == 2 and ode_step > 0
|
78 |
+
assert semantic_token.shape[0] == 1
|
79 |
+
|
80 |
+
semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)
|
81 |
+
|
82 |
+
semantic_token = semantic_token.squeeze(0)
|
83 |
+
|
84 |
+
if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
|
85 |
+
semantic_token_previous = self.previous_chunk_left["semantic_token"]
|
86 |
+
semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1)
|
87 |
+
|
88 |
+
x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)
|
89 |
+
|
90 |
+
if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
|
91 |
+
self.previous_chunk_left = {"semantic_token": None}
|
92 |
+
|
93 |
+
speech_mel = self.semantic_fm.infer_chunk(
|
94 |
+
xt_chunk=x_t_chunk,
|
95 |
+
semantic_tokens_chunk=semantic_token,
|
96 |
+
start_position_id=self.semantic_fm.start_position_id,
|
97 |
+
ode_steps=ode_step,
|
98 |
+
verbose=verbose,
|
99 |
+
look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,
|
100 |
+
cache=self.previous_chunk_left,
|
101 |
+
ode_solver=ode_solver
|
102 |
+
)
|
103 |
+
|
104 |
+
chunk_size = speech_mel.shape[0]
|
105 |
+
length = speech_mel.shape[0]
|
106 |
+
self.semantic_fm.start_position_id += length
|
107 |
+
self.semantic_fm.update_incremental_state()
|
108 |
+
self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens
|
109 |
+
|
110 |
+
# smoothing
|
111 |
+
|
112 |
+
# I will maintain the history of seqlen wav
|
113 |
+
# For the first chunk, I will only return the half chunk wav, and save the res half chunk in history
|
114 |
+
# For the rest requests, I will concat the generated wav with the history, output one chunk of the history, save the
|
115 |
+
|
116 |
+
if self.pre_mel is None: # first chunk, related to TTFB
|
117 |
+
concat_mel = speech_mel
|
118 |
+
concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
|
119 |
+
if is_final:
|
120 |
+
self.clear_states()
|
121 |
+
self.state_dict_backup = None
|
122 |
+
ret_wav = concat_reconstructed_wav.float()
|
123 |
+
else:
|
124 |
+
reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk
|
125 |
+
|
126 |
+
self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step
|
127 |
+
self.pre_mel = speech_mel[-chunk_size//2:, :]
|
128 |
+
|
129 |
+
ret_wav = reconstructed_wav.float()
|
130 |
+
else:
|
131 |
+
concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)
|
132 |
+
concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
|
133 |
+
|
134 |
+
if is_final:
|
135 |
+
self.clear_states()
|
136 |
+
self.state_dict_backup = None
|
137 |
+
ret_wav = concat_reconstructed_wav.float()
|
138 |
+
else:
|
139 |
+
# fetch history
|
140 |
+
prev_speech_len = self.pre_wav.shape[1]
|
141 |
+
|
142 |
+
if concat_reconstructed_wav.shape[1] > prev_speech_len * 2:
|
143 |
+
gen_speech_len = prev_speech_len * 2
|
144 |
+
else:
|
145 |
+
gen_speech_len = concat_reconstructed_wav.shape[1] // 2
|
146 |
+
|
147 |
+
|
148 |
+
reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # return the first half chunk
|
149 |
+
|
150 |
+
if gen_speech_len not in self.hamming_window_cache:
|
151 |
+
self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)
|
152 |
+
|
153 |
+
hamming_window = self.hamming_window_cache[gen_speech_len]
|
154 |
+
|
155 |
+
|
156 |
+
# apply smoothing of the first half chunk
|
157 |
+
reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \
|
158 |
+
reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)]
|
159 |
+
|
160 |
+
res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len
|
161 |
+
res_mel_len = res_speech_len // self.frame_size
|
162 |
+
|
163 |
+
self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]
|
164 |
+
self.pre_mel = speech_mel[-res_mel_len:, :]
|
165 |
+
ret_wav = reconstructed_wav.float()
|
166 |
+
|
167 |
+
if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size:
|
168 |
+
# out of position id,
|
169 |
+
self.semantic_fm.clear_all_states()
|
170 |
+
self.semantic_fm.load_state_dict(self.state_dict_backup)
|
171 |
+
|
172 |
+
return ret_wav
|
173 |
+
|
174 |
+
def clear_states(self):
|
175 |
+
self.semantic_fm.clear_all_states()
|
176 |
+
self.previous_chunk_left = None
|
177 |
+
self.pre_mel = None
|
178 |
+
self.pre_wav = None
|
179 |
+
|
180 |
+
def get_audio_detokenizer():
|
181 |
+
fm_model_config = "resources/audio_detokenizer/config.yaml"
|
182 |
+
fm_ckpt_path = "resources/audio_detokenizer/model.pt"
|
183 |
+
|
184 |
+
bigvgan_config_file = "resources/vocoder/config.json"
|
185 |
+
bigvgan_ckpt_path = "resources/vocoder/model.pt"
|
186 |
+
|
187 |
+
device=torch.cuda.current_device()
|
188 |
+
detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
|
189 |
+
vocoder_config=bigvgan_config_file,
|
190 |
+
vocoder_ckpt=bigvgan_ckpt_path,
|
191 |
+
max_prompt_chunk=10, # 10 * 3 = 30s
|
192 |
+
fm_config=fm_model_config,
|
193 |
+
fm_ckpt=fm_ckpt_path,
|
194 |
+
device=device,
|
195 |
+
use_cfg=False,
|
196 |
+
look_ahead_tokens=12)
|
197 |
+
|
198 |
+
return detokenizer
|
199 |
+
|
200 |
+
|
201 |
+
def detokenize(detokenizer, tokens, ref_wav, ref_tokens, streaming=False):
|
202 |
+
with torch.no_grad():
|
203 |
+
detokenizer.clear_states()
|
204 |
+
detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
|
205 |
+
cache_speech_collection = []
|
206 |
+
chunk_size = 150
|
207 |
+
first_chunk_size = 100
|
208 |
+
first_chunk_tokens = tokens[:, :first_chunk_size]
|
209 |
+
gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
|
210 |
+
if streaming:
|
211 |
+
yield gen_speech
|
212 |
+
else:
|
213 |
+
cache_speech_collection.append(gen_speech)
|
214 |
+
res_tokens = tokens[:, first_chunk_size:]
|
215 |
+
for i in range(0, res_tokens.size(1), chunk_size):
|
216 |
+
chunk_tokens = res_tokens[:, i:i+chunk_size]
|
217 |
+
gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
|
218 |
+
if streaming:
|
219 |
+
yield gen_speech
|
220 |
+
else:
|
221 |
+
cache_speech_collection.append(gen_speech)
|
222 |
+
if not streaming:
|
223 |
+
gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
|
224 |
+
return gen_speech_all
|
225 |
+
|
226 |
+
def detokenize_noref(detokenizer, tokens, streaming=False):
|
227 |
+
with torch.no_grad():
|
228 |
+
detokenizer.clear_states()
|
229 |
+
cache_speech_collection = []
|
230 |
+
chunk_size = 150
|
231 |
+
first_chunk_size = 100
|
232 |
+
first_chunk_tokens = tokens[:, :first_chunk_size]
|
233 |
+
gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size)
|
234 |
+
if streaming:
|
235 |
+
yield gen_speech
|
236 |
+
else:
|
237 |
+
cache_speech_collection.append(gen_speech)
|
238 |
+
res_tokens = tokens[:, first_chunk_size:]
|
239 |
+
for i in range(0, res_tokens.size(1), chunk_size):
|
240 |
+
chunk_tokens = res_tokens[:, i:i+chunk_size]
|
241 |
+
gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1)))
|
242 |
+
if streaming:
|
243 |
+
yield gen_speech
|
244 |
+
else:
|
245 |
+
cache_speech_collection.append(gen_speech)
|
246 |
+
if not streaming:
|
247 |
+
gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
|
248 |
+
return gen_speech_all
|
249 |
+
|
modules/audio_detokenizer/bigvgan_wrapper.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import librosa
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
|
9 |
+
from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class BigVGANWrapper:
|
15 |
+
def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
|
16 |
+
self.vocoder = vocoder.to(device)
|
17 |
+
if dtype is not None:
|
18 |
+
self.vocoder = self.vocoder.to(dtype)
|
19 |
+
self.vocoder = self.vocoder.eval()
|
20 |
+
self.device = device
|
21 |
+
self.h = h
|
22 |
+
|
23 |
+
def to_dtype(self, dtype):
|
24 |
+
self.vocoder = self.vocoder.to(dtype)
|
25 |
+
|
26 |
+
def extract_mel_from_wav(self, wav_path=None, wav_data=None):
|
27 |
+
"""
|
28 |
+
params:
|
29 |
+
wav_path: str, path of the wav, should be 24k
|
30 |
+
wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
|
31 |
+
return:
|
32 |
+
mel: [T, num_mels], torch.tensor
|
33 |
+
"""
|
34 |
+
if wav_data is None:
|
35 |
+
wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
|
36 |
+
|
37 |
+
wav_data = torch.tensor(wav_data).unsqueeze(0)
|
38 |
+
|
39 |
+
mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
|
40 |
+
hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
|
41 |
+
return mel.squeeze(0).transpose(0, 1)
|
42 |
+
|
43 |
+
@torch.inference_mode()
|
44 |
+
def extract_mel_from_wav_batch(self, wav_data):
|
45 |
+
"""
|
46 |
+
params:
|
47 |
+
wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
|
48 |
+
return:
|
49 |
+
mel: [Batch, T, num_mels], torch.tensor
|
50 |
+
"""
|
51 |
+
|
52 |
+
wav_data = torch.tensor(wav_data)
|
53 |
+
|
54 |
+
mel = get_melspec(wav=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
|
55 |
+
hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
|
56 |
+
return mel.transpose(1, 2)
|
57 |
+
|
58 |
+
def decode_mel(self, mel):
|
59 |
+
"""
|
60 |
+
params:
|
61 |
+
mel: [T, num_mels], torch.tensor
|
62 |
+
return:
|
63 |
+
wav: [1, T], torch.tensor
|
64 |
+
"""
|
65 |
+
mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
|
66 |
+
wav = self.vocoder(mel)
|
67 |
+
return wav.squeeze(0)
|
68 |
+
|
69 |
+
def decode_mel_batch(self, mel):
|
70 |
+
"""
|
71 |
+
params:
|
72 |
+
mel: [B, T, num_mels], torch.tensor
|
73 |
+
return:
|
74 |
+
wav: [B, 1, T], torch.tensor
|
75 |
+
"""
|
76 |
+
mel = mel.transpose(1, 2).to(self.device)
|
77 |
+
wav = self.vocoder(mel)
|
78 |
+
return wav
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def from_pretrained(cls, model_config, ckpt_path, device):
|
82 |
+
with open(model_config) as f:
|
83 |
+
data = f.read()
|
84 |
+
json_config = json.loads(data)
|
85 |
+
h = AttrDict(json_config)
|
86 |
+
vocoder = BigVGAN(h, True)
|
87 |
+
state_dict_g = load_checkpoint(ckpt_path, "cpu")
|
88 |
+
vocoder.load_state_dict(state_dict_g["generator"])
|
89 |
+
|
90 |
+
logger.info(">>> Load vocoder from {}".format(ckpt_path))
|
91 |
+
return cls(vocoder, device, h)
|
92 |
+
|
93 |
+
|
94 |
+
|
modules/audio_detokenizer/flow_matching/dit_block.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func
|
10 |
+
|
11 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
12 |
+
# x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2
|
13 |
+
# the last shape is "self.hidden_dim / 2" because we convert to complex
|
14 |
+
assert x.ndim == 4
|
15 |
+
assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \
|
16 |
+
f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
|
17 |
+
|
18 |
+
# reshape freq cis to match and apply pointwise multiply
|
19 |
+
# new shape: bsz, seq_len, 1, self.head_hidden_dim / 2
|
20 |
+
shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]
|
21 |
+
return freqs_cis.view(*shape)
|
22 |
+
|
23 |
+
|
24 |
+
def apply_rotary_emb(
|
25 |
+
xq: torch.Tensor,
|
26 |
+
xk: torch.Tensor,
|
27 |
+
freqs_cis: torch.Tensor,
|
28 |
+
):
|
29 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
30 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
31 |
+
|
32 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
33 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
34 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
35 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class Attention(nn.Module):
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
dim: int,
|
44 |
+
num_heads: int = 8,
|
45 |
+
qkv_bias: bool = False,
|
46 |
+
qk_norm: bool = False,
|
47 |
+
attn_drop: float = 0.,
|
48 |
+
proj_drop: float = 0.,
|
49 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
50 |
+
flash_attention: bool = True
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
54 |
+
self.num_heads = num_heads
|
55 |
+
self.head_dim = dim // num_heads
|
56 |
+
self.scale = self.head_dim ** -0.5
|
57 |
+
self.fused_attn = flash_attention
|
58 |
+
|
59 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
60 |
+
self.qk_norm = qk_norm
|
61 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
62 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
63 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
64 |
+
self.proj = nn.Linear(dim, dim)
|
65 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
66 |
+
|
67 |
+
def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
|
68 |
+
B, N, C = x.shape
|
69 |
+
|
70 |
+
if self.fused_attn:
|
71 |
+
if nopadding:
|
72 |
+
qkv = self.qkv(x)
|
73 |
+
qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
|
74 |
+
q, k, v = qkv.split([self.num_heads] * 3, dim=1)
|
75 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
76 |
+
|
77 |
+
q = q.view(B, N, self.num_heads, self.head_dim)
|
78 |
+
k = k.view(B, N, self.num_heads, self.head_dim)
|
79 |
+
v = v.view(B, N, self.num_heads, self.head_dim)
|
80 |
+
|
81 |
+
if rotary_pos_emb is not None:
|
82 |
+
q, k = apply_rotary_emb(q, k, rotary_pos_emb)
|
83 |
+
|
84 |
+
if incremental_state is not None:
|
85 |
+
if "prev_k" in incremental_state:
|
86 |
+
prev_k = incremental_state["prev_k"]
|
87 |
+
k = torch.cat([prev_k, k], dim=1)
|
88 |
+
|
89 |
+
if "cur_k" not in incremental_state:
|
90 |
+
incremental_state["cur_k"] = {}
|
91 |
+
incremental_state["cur_k"] = k
|
92 |
+
|
93 |
+
if "prev_v" in incremental_state:
|
94 |
+
prev_v = incremental_state["prev_v"]
|
95 |
+
v = torch.cat([prev_v, v], dim=1)
|
96 |
+
|
97 |
+
if "cur_v" not in incremental_state:
|
98 |
+
incremental_state["cur_v"] = {}
|
99 |
+
incremental_state["cur_v"] = v
|
100 |
+
|
101 |
+
q = q.view(B * N, self.num_heads, self.head_dim)
|
102 |
+
k = k.view(-1, self.num_heads, self.head_dim)
|
103 |
+
v = v.view(-1, self.num_heads, self.head_dim)
|
104 |
+
|
105 |
+
x = flash_attn_varlen_func(
|
106 |
+
q=q,
|
107 |
+
k=k,
|
108 |
+
v=v,
|
109 |
+
cu_seqlens_q=cu_seqlens,
|
110 |
+
cu_seqlens_k=cu_seqlens_k,
|
111 |
+
max_seqlen_q=max_seqlen,
|
112 |
+
max_seqlen_k=max_seqlen_k,
|
113 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
|
117 |
+
if incremental_state is not None:
|
118 |
+
raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.")
|
119 |
+
|
120 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
121 |
+
if self.qk_norm:
|
122 |
+
q, k, v = qkv.unbind(2)
|
123 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
124 |
+
# re-bind
|
125 |
+
qkv = torch.stack((q, k, v), dim=2)
|
126 |
+
|
127 |
+
# pack qkv with seq_len
|
128 |
+
qkv_collect = []
|
129 |
+
for i in range(qkv.shape[0]):
|
130 |
+
qkv_collect.append(
|
131 |
+
qkv[i, :seq_len[i], :, :, :]
|
132 |
+
)
|
133 |
+
|
134 |
+
qkv = torch.cat(qkv_collect, dim=0)
|
135 |
+
|
136 |
+
x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)
|
137 |
+
|
138 |
+
# unpack and pad 0
|
139 |
+
x_collect = []
|
140 |
+
for i in range(B):
|
141 |
+
x_collect.append(
|
142 |
+
x[cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
143 |
+
)
|
144 |
+
x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)
|
145 |
+
|
146 |
+
else:
|
147 |
+
q = q * self.scale
|
148 |
+
attn = q @ k.transpose(-2, -1)
|
149 |
+
attn = attn.softmax(dim=-1)
|
150 |
+
attn = self.attn_drop(attn)
|
151 |
+
x = attn @ v
|
152 |
+
x = x.transpose(1, 2)
|
153 |
+
|
154 |
+
x = x.reshape(B, N, C)
|
155 |
+
x = self.proj(x)
|
156 |
+
x = self.proj_drop(x)
|
157 |
+
return x
|
158 |
+
|
159 |
+
|
160 |
+
def modulate(x, shift, scale):
|
161 |
+
return x * (1 + scale) + shift
|
162 |
+
|
163 |
+
|
164 |
+
class FinalLayer(nn.Module):
|
165 |
+
"""
|
166 |
+
The final layer of DiT.
|
167 |
+
"""
|
168 |
+
def __init__(self, hidden_size, out_channels):
|
169 |
+
super().__init__()
|
170 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
171 |
+
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
172 |
+
self.adaLN_modulation = nn.Sequential(
|
173 |
+
nn.SiLU(),
|
174 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
175 |
+
)
|
176 |
+
|
177 |
+
def forward(self, x, c):
|
178 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
|
179 |
+
x = modulate(self.norm_final(x), shift, scale)
|
180 |
+
x = self.linear(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
|
184 |
+
class DiTBlock(nn.Module):
|
185 |
+
"""
|
186 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
187 |
+
"""
|
188 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
|
189 |
+
super().__init__()
|
190 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
191 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
192 |
+
|
193 |
+
|
194 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
195 |
+
|
196 |
+
if ffn_type == "vanilla_mlp":
|
197 |
+
from timm.models.vision_transformer import Mlp
|
198 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
199 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
200 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
201 |
+
else:
|
202 |
+
raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
|
203 |
+
|
204 |
+
self.adaLN_modulation = nn.Sequential(
|
205 |
+
nn.SiLU(),
|
206 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
207 |
+
)
|
208 |
+
|
209 |
+
def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
|
210 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
|
211 |
+
|
212 |
+
x_ = modulate(self.norm1(x), shift_msa, scale_msa)
|
213 |
+
|
214 |
+
if incremental_state is not None:
|
215 |
+
if "attn_kvcache" not in incremental_state:
|
216 |
+
incremental_state["attn_kvcache"] = {}
|
217 |
+
inc_attn = incremental_state["attn_kvcache"]
|
218 |
+
else:
|
219 |
+
inc_attn = None
|
220 |
+
|
221 |
+
x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)
|
222 |
+
|
223 |
+
if not nopadding:
|
224 |
+
x_ = x_ * mask[:, :, None]
|
225 |
+
|
226 |
+
x = x + gate_msa * x_
|
227 |
+
|
228 |
+
x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
229 |
+
|
230 |
+
x_ = self.mlp(x_)
|
231 |
+
|
232 |
+
if not nopadding:
|
233 |
+
x_ = x_ * mask[:, :, None]
|
234 |
+
|
235 |
+
x = x + gate_mlp * x_
|
236 |
+
return x
|
modules/audio_detokenizer/flow_matching/model.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer
|
5 |
+
|
6 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
|
7 |
+
interpolation_factor: int = 1, max_seq_length: int = 4096):
|
8 |
+
print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
|
9 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
10 |
+
|
11 |
+
# ROPE type-A extention
|
12 |
+
# we choose to use interpolation rather than extrapolation for better position encoding
|
13 |
+
# for scale purposes, t should be a float tensor
|
14 |
+
t = torch.arange(end, device=freqs.device).float()
|
15 |
+
scale = 1.0 / float(interpolation_factor)
|
16 |
+
t *= scale
|
17 |
+
|
18 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
19 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
20 |
+
|
21 |
+
# Sometimes, we don't need so many rope emb as seq_len is smaller than max_pos_emb
|
22 |
+
# e.g. rope 1M but seqlen 32k, this will cause gpu memory waste
|
23 |
+
if max_seq_length < end:
|
24 |
+
freqs_cis = freqs_cis[:max_seq_length,].clone()
|
25 |
+
return freqs_cis
|
26 |
+
|
27 |
+
|
28 |
+
class TimestepEmbedder(nn.Module):
|
29 |
+
"""
|
30 |
+
Embeds scalar timesteps into vector representations.
|
31 |
+
"""
|
32 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
33 |
+
super().__init__()
|
34 |
+
self.mlp = nn.Sequential(
|
35 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
36 |
+
nn.SiLU(),
|
37 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
38 |
+
)
|
39 |
+
self.frequency_embedding_size = frequency_embedding_size
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def timestep_embedding(t, dim, max_period=10000):
|
43 |
+
"""
|
44 |
+
Create sinusoidal timestep embeddings.
|
45 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
46 |
+
These may be fractional.
|
47 |
+
:param dim: the dimension of the output.
|
48 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
49 |
+
:return: an (N, D) Tensor of positional embeddings.
|
50 |
+
"""
|
51 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
52 |
+
half = dim // 2
|
53 |
+
freqs = torch.exp(
|
54 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
55 |
+
).float().to(device=t.device)
|
56 |
+
args = t[:, None].float() * freqs[None]
|
57 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
58 |
+
if dim % 2:
|
59 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
60 |
+
return embedding
|
61 |
+
|
62 |
+
def forward(self, t):
|
63 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
64 |
+
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
65 |
+
return t_emb
|
66 |
+
|
67 |
+
|
68 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
69 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
70 |
+
|
71 |
+
Padding symbols are ignored.
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
75 |
+
super().__init__()
|
76 |
+
self.embedding_dim = embedding_dim
|
77 |
+
self.padding_idx = padding_idx
|
78 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
79 |
+
init_size,
|
80 |
+
embedding_dim,
|
81 |
+
padding_idx,
|
82 |
+
)
|
83 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
87 |
+
"""Build sinusoidal embeddings.
|
88 |
+
|
89 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
90 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
91 |
+
"""
|
92 |
+
half_dim = embedding_dim // 2 # d/2
|
93 |
+
emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2)
|
94 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, )
|
95 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2)
|
96 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) # shape: (num_embeddings, d)
|
97 |
+
if embedding_dim % 2 == 1:
|
98 |
+
# zero pad
|
99 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
100 |
+
if padding_idx is not None:
|
101 |
+
emb[padding_idx, :] = 0
|
102 |
+
return emb
|
103 |
+
|
104 |
+
def forward(self, input, incremental_state=None, timestep=None, **kwargs):
|
105 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
106 |
+
bsz, seq_len = input.shape[:2]
|
107 |
+
max_pos = self.padding_idx + 1 + seq_len
|
108 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
109 |
+
# recompute/expand embeddings if needed
|
110 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
111 |
+
max_pos,
|
112 |
+
self.embedding_dim,
|
113 |
+
self.padding_idx,
|
114 |
+
)
|
115 |
+
self.weights = self.weights.to(self._float_tensor)
|
116 |
+
|
117 |
+
if incremental_state is not None:
|
118 |
+
# positions is the same for every token when decoding a single step
|
119 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
120 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
121 |
+
|
122 |
+
positions = self.make_positions(input, self.padding_idx)
|
123 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() # (B, T, dim)
|
124 |
+
|
125 |
+
def max_positions(self):
|
126 |
+
"""Maximum number of supported positions."""
|
127 |
+
return int(1e5) # an arbitrary large number
|
128 |
+
|
129 |
+
def make_positions(self, tensor, padding_idx):
|
130 |
+
"""Replace non-padding symbols with their position numbers.
|
131 |
+
|
132 |
+
Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
133 |
+
"""
|
134 |
+
# The series of casts and type-conversions here are carefully
|
135 |
+
# balanced to both work with ONNX export and XLA. In particular XLA
|
136 |
+
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
|
137 |
+
# how to handle the dtype kwarg in cumsum.
|
138 |
+
mask = tensor.ne(padding_idx).int()
|
139 |
+
return (
|
140 |
+
torch.cumsum(mask, dim=1).type_as(mask) * mask
|
141 |
+
).long() + padding_idx
|
142 |
+
|
143 |
+
|
144 |
+
class DiTPrefix(nn.Module):
|
145 |
+
"""
|
146 |
+
Diffusion model with a Transformer backbone.
|
147 |
+
"""
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
input_size,
|
151 |
+
output_size,
|
152 |
+
semantic_vocab_size,
|
153 |
+
hidden_size=1024,
|
154 |
+
depth=12,
|
155 |
+
num_heads=4,
|
156 |
+
# mlp related
|
157 |
+
mlp_ratio=4.0,
|
158 |
+
ffn_type="conv1d_conv1d",
|
159 |
+
ffn_gated_glu=True,
|
160 |
+
ffn_act_layer="gelu",
|
161 |
+
ffn_conv_kernel_size=5,
|
162 |
+
|
163 |
+
# rope
|
164 |
+
use_rope=False,
|
165 |
+
rope_params={
|
166 |
+
"max_position_embeddings": 4096,
|
167 |
+
"rope_base": 10000.0,
|
168 |
+
"rope_interpolation_factor": 1.0,
|
169 |
+
},
|
170 |
+
|
171 |
+
|
172 |
+
position_embedding_type="sincos",
|
173 |
+
max_seq_len=4096,
|
174 |
+
prompt_cfg_dropout=0.0
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
self.num_heads = num_heads
|
178 |
+
|
179 |
+
self.prompt_cfg_dropout = prompt_cfg_dropout
|
180 |
+
|
181 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
182 |
+
|
183 |
+
self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
|
184 |
+
|
185 |
+
self.input_linear = nn.Linear(input_size, hidden_size)
|
186 |
+
|
187 |
+
# position embedding
|
188 |
+
if position_embedding_type == "learnable":
|
189 |
+
self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)
|
190 |
+
elif position_embedding_type == "sincos":
|
191 |
+
self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)
|
192 |
+
elif position_embedding_type == "skip":
|
193 |
+
self.position_embedding = None
|
194 |
+
else:
|
195 |
+
raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))
|
196 |
+
|
197 |
+
self.use_rope = use_rope
|
198 |
+
|
199 |
+
if self.use_rope:
|
200 |
+
|
201 |
+
assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
|
202 |
+
rope_dim = hidden_size // num_heads
|
203 |
+
|
204 |
+
self.rotary_pos_emb = precompute_freqs_cis(
|
205 |
+
rope_dim, rope_params["max_position_embeddings"],
|
206 |
+
theta=rope_params["rope_base"],
|
207 |
+
interpolation_factor=rope_params["rope_interpolation_factor"],
|
208 |
+
max_seq_length=max_seq_len
|
209 |
+
)
|
210 |
+
|
211 |
+
self.blocks = nn.ModuleList([
|
212 |
+
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
|
213 |
+
ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)
|
214 |
+
])
|
215 |
+
self.final_layer = FinalLayer(hidden_size, output_size)
|
216 |
+
self.initialize_weights()
|
217 |
+
|
218 |
+
def initialize_weights(self):
|
219 |
+
# Initialize transformer layers:
|
220 |
+
def _basic_init(module):
|
221 |
+
if isinstance(module, nn.Linear):
|
222 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
223 |
+
if module.bias is not None:
|
224 |
+
nn.init.constant_(module.bias, 0)
|
225 |
+
self.apply(_basic_init)
|
226 |
+
|
227 |
+
|
228 |
+
# Initialize timestep embedding MLP:
|
229 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
230 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
231 |
+
|
232 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
233 |
+
for block in self.blocks:
|
234 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
235 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
236 |
+
|
237 |
+
# Zero-out output layers:
|
238 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
239 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
240 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
241 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
242 |
+
|
243 |
+
def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
|
244 |
+
"""
|
245 |
+
Forward pass of DiT.
|
246 |
+
x: (N, T, C) tensor of inputs (latent representations of speech)
|
247 |
+
position_ids: (N, T) tensor of positional indices
|
248 |
+
t: (N,) tensor of diffusion timesteps
|
249 |
+
condition: (N, T) tensor of semantic tokens
|
250 |
+
seq_len: (N,) tensor of sequence lengths
|
251 |
+
"""
|
252 |
+
|
253 |
+
condition = self.semantic_token_embedding(condition) # (N, T, D)
|
254 |
+
|
255 |
+
x = self.input_linear(x)
|
256 |
+
|
257 |
+
if self.position_embedding is not None:
|
258 |
+
position_emb = self.position_embedding(position_ids)
|
259 |
+
x = x + position_emb
|
260 |
+
|
261 |
+
# ROPE
|
262 |
+
if self.use_rope:
|
263 |
+
bsz, seqlen = position_ids.shape
|
264 |
+
if self.rotary_pos_emb.device != position_ids.device:
|
265 |
+
self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
|
266 |
+
rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]),
|
267 |
+
dtype=self.rotary_pos_emb.dtype,
|
268 |
+
device=self.rotary_pos_emb.device)
|
269 |
+
for b in range(bsz):
|
270 |
+
cur_rope = rotary_pos_emb[b]
|
271 |
+
cur_position_ids = position_ids[b]
|
272 |
+
cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
|
273 |
+
else:
|
274 |
+
rotary_pos_emb = None
|
275 |
+
|
276 |
+
t = self.t_embedder(t) # (N, D)
|
277 |
+
c = t.unsqueeze(1) + condition # (N, T, D)
|
278 |
+
|
279 |
+
|
280 |
+
for block_idx, block in enumerate(self.blocks):
|
281 |
+
# x = block(x, c, attn_mask) # (N, T, D)
|
282 |
+
# XXX mask could be None because we always use full mask
|
283 |
+
|
284 |
+
if incremental_state is not None:
|
285 |
+
if block_idx not in incremental_state:
|
286 |
+
incremental_state[block_idx] = {}
|
287 |
+
incr = incremental_state[block_idx]
|
288 |
+
else:
|
289 |
+
incr = None
|
290 |
+
|
291 |
+
x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)
|
292 |
+
|
293 |
+
x = self.final_layer(x, c) # (N, T, C)
|
294 |
+
return x
|
295 |
+
|
modules/audio_detokenizer/flow_matching/ode_wrapper.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from functools import lru_cache
|
4 |
+
import copy
|
5 |
+
|
6 |
+
|
7 |
+
@lru_cache(maxsize=1)
|
8 |
+
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
|
9 |
+
return torch.zeros(numel, device=device, dtype=dtype)
|
10 |
+
|
11 |
+
class StreamingODEWrapperForPrefix(nn.Module):
|
12 |
+
def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
|
13 |
+
super(StreamingODEWrapperForPrefix, self).__init__()
|
14 |
+
self.net = net
|
15 |
+
self.x_mask = x_mask
|
16 |
+
self.x_cond = x_cond
|
17 |
+
|
18 |
+
assert use_cfg == False, "cfg is not supported in streaming detokenizer"
|
19 |
+
|
20 |
+
self.use_cfg = use_cfg
|
21 |
+
self.use_cfg_rescale = use_cfg_rescale
|
22 |
+
self.cfg_init = cfg_init
|
23 |
+
self.cfg_scale = cfg_scale
|
24 |
+
self.cfg_token_id = cfg_token_id
|
25 |
+
self.cfg_schedule = cfg_schedule
|
26 |
+
self.position_ids = None
|
27 |
+
self.seq_len = None
|
28 |
+
|
29 |
+
self.incremental_state = {}
|
30 |
+
self.kv_cache_tokens = 0
|
31 |
+
self.cu_seqlens = None
|
32 |
+
self.cu_maxlen = None
|
33 |
+
|
34 |
+
self.cu_seqlens_k = None
|
35 |
+
self.cu_maxlen_k = None
|
36 |
+
self.previous_seqlen = None
|
37 |
+
|
38 |
+
def clear_all_states(self):
|
39 |
+
self.incremental_state = {}
|
40 |
+
self.kv_cache_tokens = 0
|
41 |
+
self.cu_seqlens = None
|
42 |
+
self.cu_maxlen = None
|
43 |
+
|
44 |
+
self.cu_seqlens_k = None
|
45 |
+
self.cu_maxlen_k = None
|
46 |
+
self.previous_seqlen = None
|
47 |
+
|
48 |
+
def state_dict(self):
|
49 |
+
return {
|
50 |
+
"incremental_state": copy.deepcopy(self.incremental_state),
|
51 |
+
"kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
|
52 |
+
"cu_seqlens": copy.deepcopy(self.cu_seqlens),
|
53 |
+
"cu_maxlen": copy.deepcopy(self.cu_maxlen),
|
54 |
+
"cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
|
55 |
+
"cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
|
56 |
+
"previous_seqlen": copy.deepcopy(self.previous_seqlen)
|
57 |
+
}
|
58 |
+
|
59 |
+
def load_state_dict(self, state_dict):
|
60 |
+
self.incremental_state = state_dict["incremental_state"]
|
61 |
+
self.kv_cache_tokens = state_dict["kv_cache_tokens"]
|
62 |
+
self.cu_seqlens = state_dict["cu_seqlens"]
|
63 |
+
self.cu_maxlen = state_dict["cu_maxlen"]
|
64 |
+
self.cu_seqlens_k = state_dict["cu_seqlens_k"]
|
65 |
+
self.cu_maxlen_k = state_dict["cu_maxlen_k"]
|
66 |
+
self.previous_seqlen = state_dict["previous_seqlen"]
|
67 |
+
|
68 |
+
def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
|
69 |
+
if not self.use_cfg:
|
70 |
+
self.x_mask = x_mask
|
71 |
+
self.x_cond = x_cond
|
72 |
+
else:
|
73 |
+
self.x_cond = torch.cat((x_cond, x_cond), dim=0)
|
74 |
+
self.x_mask = torch.cat((x_mask, x_mask), dim=0)
|
75 |
+
|
76 |
+
position_ids_cur = [i for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)]
|
77 |
+
position_ids = torch.tensor([position_ids_cur])
|
78 |
+
|
79 |
+
|
80 |
+
if not self.use_cfg:
|
81 |
+
self.position_ids = position_ids.to(self.x_cond.device).long()
|
82 |
+
self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
|
83 |
+
else:
|
84 |
+
self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
|
85 |
+
self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
|
86 |
+
|
87 |
+
cu_seqlens = torch.cumsum(self.seq_len, dim=0)
|
88 |
+
self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
|
89 |
+
self.cu_maxlen = self.seq_len.cpu().max()
|
90 |
+
|
91 |
+
if self.cu_seqlens_k is None:
|
92 |
+
self.cu_seqlens_k = self.cu_seqlens
|
93 |
+
self.cu_maxlen_k = self.cu_maxlen
|
94 |
+
previous_seqlen = self.seq_len
|
95 |
+
else:
|
96 |
+
previous_seqlen_old = cache["previous_seqlen"]
|
97 |
+
previous_seqlen = previous_seqlen_old + self.seq_len
|
98 |
+
# calculate cu_seqlens_k
|
99 |
+
cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
|
100 |
+
self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
|
101 |
+
self.cu_maxlen_k = previous_seqlen.cpu().max()
|
102 |
+
self.previous_seqlen = previous_seqlen
|
103 |
+
ret_cache = {
|
104 |
+
"previous_seqlen": previous_seqlen
|
105 |
+
}
|
106 |
+
return ret_cache
|
107 |
+
|
108 |
+
def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache={"previous_seqlen"}):
|
109 |
+
|
110 |
+
assert reserve_kv_cache_tokens <= max_kv_cache_tokens, "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
|
111 |
+
|
112 |
+
for layer_idx, layer_cache in self.incremental_state.items():
|
113 |
+
# update attention kv cache
|
114 |
+
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
|
115 |
+
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
|
116 |
+
|
117 |
+
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
|
118 |
+
|
119 |
+
if self.kv_cache_tokens > max_kv_cache_tokens:
|
120 |
+
# drop old tokens from reserve kv cache tokens to max_kv_cache_tokens
|
121 |
+
reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
|
122 |
+
|
123 |
+
if reserve_kv_cache_tokens == 0:
|
124 |
+
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
|
125 |
+
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
|
126 |
+
elif reserve_tokens_excludeprompt == 0:
|
127 |
+
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
|
128 |
+
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
|
129 |
+
else:
|
130 |
+
layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
|
131 |
+
layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
|
132 |
+
layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
|
133 |
+
], dim=1)
|
134 |
+
|
135 |
+
layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
|
136 |
+
layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
|
137 |
+
layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
|
138 |
+
], dim=1)
|
139 |
+
|
140 |
+
|
141 |
+
bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
|
142 |
+
self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
|
143 |
+
condition_cache["previous_seqlen"] = self.previous_seqlen
|
144 |
+
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
|
145 |
+
|
146 |
+
# clear current cache
|
147 |
+
layer_cache["attn_kvcache"].pop("cur_k")
|
148 |
+
layer_cache["attn_kvcache"].pop("cur_v")
|
149 |
+
|
150 |
+
|
151 |
+
def forward(self, t, x, args=None):
|
152 |
+
# t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long()
|
153 |
+
t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
|
154 |
+
|
155 |
+
if self.use_cfg:
|
156 |
+
raise NotImplementedError("cfg is not supported in streaming detokenizer.")
|
157 |
+
else:
|
158 |
+
pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids,
|
159 |
+
cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
|
160 |
+
cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
|
161 |
+
incremental_state=self.incremental_state, nopadding=True,
|
162 |
+
mask=None, seq_len=None
|
163 |
+
)
|
164 |
+
return pred_noise
|
modules/audio_detokenizer/flow_matching/scheduler.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from abc import abstractmethod, ABC
|
3 |
+
try:
|
4 |
+
from torchdyn.core import NeuralODE
|
5 |
+
NEURALODE_INSTALLED = True
|
6 |
+
except ImportError:
|
7 |
+
NEURALODE_INSTALLED = False
|
8 |
+
|
9 |
+
class SchedulerBase(ABC):
|
10 |
+
def __init__(self) -> None:
|
11 |
+
pass
|
12 |
+
|
13 |
+
@abstractmethod
|
14 |
+
def set_timesteps(self):
|
15 |
+
pass
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def step(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
@abstractmethod
|
22 |
+
def add_noise(self):
|
23 |
+
pass
|
24 |
+
|
25 |
+
|
26 |
+
class StreamingFlowMatchingScheduler(SchedulerBase):
|
27 |
+
def __init__(self, timesteps=1000, sigma_min=1e-4,
|
28 |
+
) -> None:
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
self.sigma_min = sigma_min
|
32 |
+
self.timesteps = timesteps
|
33 |
+
self.t_min = 0
|
34 |
+
self.t_max = 1 - self.sigma_min
|
35 |
+
|
36 |
+
self.neural_ode = None
|
37 |
+
|
38 |
+
|
39 |
+
def set_timesteps(self, timesteps=15):
|
40 |
+
self.timesteps = timesteps
|
41 |
+
|
42 |
+
def step(self, xt, predicted_v):
|
43 |
+
|
44 |
+
h = (self.t_max - self.t_min) / self.timesteps
|
45 |
+
h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
|
46 |
+
|
47 |
+
xt = xt + h * predicted_v
|
48 |
+
return xt
|
49 |
+
|
50 |
+
def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
|
51 |
+
h = (self.t_max - self.t_min) / self.timesteps
|
52 |
+
h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
|
53 |
+
|
54 |
+
if verbose:
|
55 |
+
gt_v = x0 - xt
|
56 |
+
|
57 |
+
for t in time_steps:
|
58 |
+
predicted_v = ode_wrapper(t, xt)
|
59 |
+
if verbose:
|
60 |
+
dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
|
61 |
+
print("Time: {}, Distance: {}".format(t, dist))
|
62 |
+
xt = xt + h * predicted_v
|
63 |
+
return xt
|
64 |
+
|
65 |
+
def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
|
66 |
+
if not NEURALODE_INSTALLED:
|
67 |
+
raise ImportError("NeuralODE is not installed, please install it first.")
|
68 |
+
|
69 |
+
if self.neural_ode is None:
|
70 |
+
self.neural_ode = NeuralODE(ode_wrapper, solver='euler', sensitivity="adjoint", atol=self.sigma_min, rtol=self.sigma_min)
|
71 |
+
|
72 |
+
eval_points, traj = self.neural_ode(xt, time_steps)
|
73 |
+
return traj[-1]
|
74 |
+
|
75 |
+
|
76 |
+
def add_noise(self, original_samples: torch.FloatTensor,
|
77 |
+
noise: torch.FloatTensor,
|
78 |
+
timesteps: torch.IntTensor,):
|
79 |
+
ut = original_samples - (1 - self.sigma_min) * noise # 和ut的梯度没关系
|
80 |
+
t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
|
81 |
+
x_noisy = t_unsqueeze * original_samples + (1. - (1 - self.sigma_min) * t_unsqueeze) * noise
|
82 |
+
return x_noisy, ut
|
modules/audio_detokenizer/semantic_fm_prefix_streaming.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from modules.audio_detokenizer.flow_matching.ode_wrapper import StreamingODEWrapperForPrefix
|
9 |
+
from modules.audio_detokenizer.flow_matching.model import DiTPrefix
|
10 |
+
from modules.audio_detokenizer.flow_matching.scheduler import StreamingFlowMatchingScheduler
|
11 |
+
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class StreamingSemanticFMWrapper:
|
17 |
+
def __init__(self, speech_model: DiTPrefix, max_kv_cache_tokens=900, max_prompt_chunk=2,
|
18 |
+
use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear", cfg_token_id=0,
|
19 |
+
normalize_mel=False, mel_mean=None, mel_std=None, device: torch.device = torch.device("cpu")) -> None:
|
20 |
+
|
21 |
+
self.dtype = torch.bfloat16
|
22 |
+
self.speech_model = speech_model.to(device).to(self.dtype)
|
23 |
+
self.speech_model = self.speech_model.eval()
|
24 |
+
self.device = device
|
25 |
+
self.normalize_mel = normalize_mel
|
26 |
+
self.mel_mean = mel_mean
|
27 |
+
self.mel_std = mel_std
|
28 |
+
|
29 |
+
self.use_cfg = use_cfg
|
30 |
+
self.use_cfg_rescale = use_cfg_rescale
|
31 |
+
self.cfg_init = cfg_init
|
32 |
+
self.cfg_scale = cfg_scale
|
33 |
+
self.cfg_schedule = cfg_schedule
|
34 |
+
|
35 |
+
self.incremental_state = {}
|
36 |
+
self.condition_cache = {"previous_seqlen": 0}
|
37 |
+
|
38 |
+
logger.info(f">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}")
|
39 |
+
|
40 |
+
self.scheduler = StreamingFlowMatchingScheduler()
|
41 |
+
self.ode_wrapper = StreamingODEWrapperForPrefix(net=self.speech_model, x_mask=None, x_cond=None,
|
42 |
+
use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_token_id)
|
43 |
+
|
44 |
+
self.max_kv_cache_tokens = max_kv_cache_tokens
|
45 |
+
self.max_prompt_chunk = max_prompt_chunk
|
46 |
+
self.reserve_kv_cache_tokens = 0
|
47 |
+
|
48 |
+
@torch.inference_mode()
|
49 |
+
def infer_chunk(self, xt_chunk, semantic_tokens_chunk, start_position_id,
|
50 |
+
cache = None, look_ahead_tokens=0,
|
51 |
+
ode_steps=15, verbose=False, ode_solver="neural_ode_euler"):
|
52 |
+
"""
|
53 |
+
semantic_tokens: [T_1], torch.LongTensor
|
54 |
+
xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
|
55 |
+
ode_steps: int, number of ode steps, default 15
|
56 |
+
verbose: bool, default False
|
57 |
+
ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
|
58 |
+
"""
|
59 |
+
bs = 1
|
60 |
+
|
61 |
+
self.scheduler.set_timesteps(ode_steps)
|
62 |
+
|
63 |
+
semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
|
64 |
+
xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype)
|
65 |
+
|
66 |
+
t_span = torch.linspace(0, 1, self.scheduler.timesteps)
|
67 |
+
|
68 |
+
x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool()
|
69 |
+
|
70 |
+
cache_ret = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
|
71 |
+
|
72 |
+
if verbose:
|
73 |
+
t_start = time.time()
|
74 |
+
if ode_solver == "neural_ode_euler":
|
75 |
+
x_t = self.scheduler.sample_by_neuralode(self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
|
76 |
+
elif ode_solver == "naive_euler":
|
77 |
+
x_t = self.scheduler.sample(ode_wrapper=self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False)
|
78 |
+
else:
|
79 |
+
raise NotImplementedError("ode_solver should be in ('neural_ode_euler', 'naive_euler')")
|
80 |
+
|
81 |
+
if look_ahead_tokens > 0:
|
82 |
+
semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:]
|
83 |
+
cache["semantic_token"] = semantic_tokens_left
|
84 |
+
x_t_ret = x_t[:, :-look_ahead_tokens, :]
|
85 |
+
else:
|
86 |
+
x_t_ret = x_t
|
87 |
+
|
88 |
+
if look_ahead_tokens > 0:
|
89 |
+
x_mask = torch.zeros(bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device).bool()
|
90 |
+
self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], start_position_id=start_position_id, cache=self.condition_cache)
|
91 |
+
self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret)
|
92 |
+
else:
|
93 |
+
self.condition_cache = cache_ret
|
94 |
+
|
95 |
+
if verbose:
|
96 |
+
t_end = time.time()
|
97 |
+
logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}")
|
98 |
+
|
99 |
+
if self.normalize_mel:
|
100 |
+
x_t_ret = x_t_ret * self.mel_std + self.mel_mean
|
101 |
+
return x_t_ret.squeeze(0)
|
102 |
+
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def infer_mel(self, semantic_tokens, ode_steps=15, chunk_size=150, verbose=False, ode_solver="neural_ode_euler"):
|
106 |
+
"""
|
107 |
+
semantic_tokens: [T_1], torch.LongTensor
|
108 |
+
prompt: [T_2, 80], torch.Tensor, DO NOT normalize it outside
|
109 |
+
prompt_semantic_tokens, [T_2], torch.LongTensor
|
110 |
+
ode_steps: int, number of ode steps, default 15
|
111 |
+
verbose: bool, default False
|
112 |
+
ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler"
|
113 |
+
"""
|
114 |
+
assert semantic_tokens.dim() == 1
|
115 |
+
|
116 |
+
x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype)
|
117 |
+
|
118 |
+
seq_len = semantic_tokens.shape[0]
|
119 |
+
|
120 |
+
num_chunks = seq_len // chunk_size
|
121 |
+
if seq_len % chunk_size != 0:
|
122 |
+
num_chunks += 1
|
123 |
+
|
124 |
+
x_pred_collect = []
|
125 |
+
|
126 |
+
if verbose:
|
127 |
+
t_start = time.time()
|
128 |
+
|
129 |
+
for chunk_id in range(num_chunks):
|
130 |
+
start = chunk_id * chunk_size
|
131 |
+
end = min(start + chunk_size, seq_len)
|
132 |
+
semantic_tokens_chunk = semantic_tokens[start:end]
|
133 |
+
x_t_chunk = x_t[start:end, :]
|
134 |
+
|
135 |
+
x_pred = self.infer_chunk(xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_tokens_chunk, start_position_id=self.start_position_id,
|
136 |
+
ode_steps=ode_steps, verbose=verbose, ode_solver=ode_solver)
|
137 |
+
self.start_position_id += end - start
|
138 |
+
self.update_incremental_state()
|
139 |
+
|
140 |
+
x_pred_collect.append(x_pred)
|
141 |
+
|
142 |
+
if verbose:
|
143 |
+
t_end = time.time()
|
144 |
+
logger.info(f"[ODE] Time cost: {t_end - t_start}")
|
145 |
+
|
146 |
+
x_pred = torch.cat(x_pred_collect, dim=0)
|
147 |
+
|
148 |
+
return x_pred
|
149 |
+
|
150 |
+
def clear_all_states(self):
|
151 |
+
self.start_position_id = 0
|
152 |
+
self.condition_cache = {"previous_seqlen": 0}
|
153 |
+
self.ode_wrapper.clear_all_states()
|
154 |
+
|
155 |
+
def state_dict(self):
|
156 |
+
return {
|
157 |
+
"start_position_id": self.start_position_id,
|
158 |
+
"ode_wrapper": self.ode_wrapper.state_dict(),
|
159 |
+
"condition_cache": self.condition_cache
|
160 |
+
}
|
161 |
+
|
162 |
+
def load_state_dict(self, state_dict):
|
163 |
+
if state_dict is not None:
|
164 |
+
self.start_position_id = state_dict["start_position_id"]
|
165 |
+
self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"])
|
166 |
+
self.condition_cache = state_dict["condition_cache"]
|
167 |
+
|
168 |
+
def update_incremental_state(self):
|
169 |
+
self.ode_wrapper.update_incremental_state(reserve_kv_cache_tokens=0, max_kv_cache_tokens=self.max_kv_cache_tokens, condition_cache=self.condition_cache)
|
170 |
+
|
171 |
+
@torch.inference_mode()
|
172 |
+
def prefill(self, mel, semantic_token, chunk_size=150, verbose=False):
|
173 |
+
"""
|
174 |
+
mel: [T, 80], torch.Tensor
|
175 |
+
semantic_token: [T], torch.LongTensor
|
176 |
+
chunk_size: int, default 150
|
177 |
+
"""
|
178 |
+
assert mel.dim() == 2
|
179 |
+
assert semantic_token.dim() == 1
|
180 |
+
assert semantic_token.shape[0] == mel.shape[0], "Semantic token and mel shape mismatch"
|
181 |
+
seq_len = mel.shape[0]
|
182 |
+
num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk)
|
183 |
+
start_pos = seq_len - num_chunks * chunk_size
|
184 |
+
|
185 |
+
res_mel = mel[:start_pos, :]
|
186 |
+
res_semantic_token = semantic_token[:start_pos]
|
187 |
+
self.prefill_chunk(res_mel, res_semantic_token, start_position_id=self.start_position_id)
|
188 |
+
self.start_position_id += start_pos
|
189 |
+
self.update_incremental_state()
|
190 |
+
self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
|
191 |
+
|
192 |
+
if verbose:
|
193 |
+
logger.info("Prefilling prompt with {} chunks".format(num_chunks))
|
194 |
+
start_time = time.time()
|
195 |
+
|
196 |
+
for chunk_id in range(num_chunks):
|
197 |
+
start = start_pos + chunk_id * chunk_size
|
198 |
+
end = start + chunk_size
|
199 |
+
mel_chunk = mel[start:end, :]
|
200 |
+
semantic_token_chunk = semantic_token[start:end]
|
201 |
+
|
202 |
+
self.prefill_chunk(mel_chunk, semantic_token_chunk, start_position_id=self.start_position_id)
|
203 |
+
self.start_position_id += end - start
|
204 |
+
|
205 |
+
self.update_incremental_state()
|
206 |
+
self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens
|
207 |
+
|
208 |
+
|
209 |
+
if verbose:
|
210 |
+
logger.info("Prefilling done in {:.2f} seconds".format(time.time() - start_time))
|
211 |
+
|
212 |
+
def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0):
|
213 |
+
"""
|
214 |
+
mel_chunk: [T, 80], torch.Tensor, T is the chunk size
|
215 |
+
semantic_tokens_chunk: [T], torch.LongTensor
|
216 |
+
start_position_id: int, default 0
|
217 |
+
"""
|
218 |
+
bs = 1
|
219 |
+
|
220 |
+
semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device)
|
221 |
+
mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype)
|
222 |
+
|
223 |
+
if self.normalize_mel:
|
224 |
+
mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std
|
225 |
+
|
226 |
+
x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool()
|
227 |
+
|
228 |
+
self.condition_cache = self.ode_wrapper.set_conditions(x_mask=x_mask, x_cond=semantic_tokens_chunk, start_position_id=start_position_id, cache=self.condition_cache)
|
229 |
+
|
230 |
+
x_t = torch.Tensor([0.999]).to(self.device)
|
231 |
+
|
232 |
+
self.ode_wrapper(x_t, mel_chunk)
|
233 |
+
|
234 |
+
|
235 |
+
@classmethod
|
236 |
+
def from_pretrained(cls, model_config, ckpt_path, device, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=True, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
|
237 |
+
|
238 |
+
# open yaml file
|
239 |
+
with open(model_config, 'r') as f:
|
240 |
+
config = yaml.safe_load(f)
|
241 |
+
model_config = config["model"]["dit"]
|
242 |
+
dit = DiTPrefix(
|
243 |
+
input_size=model_config["input_size"],
|
244 |
+
semantic_vocab_size=model_config["semantic_vocab_size"] + 1,
|
245 |
+
hidden_size=model_config["hidden_size"],
|
246 |
+
depth=model_config["depth"],
|
247 |
+
num_heads=model_config["num_heads"],
|
248 |
+
mlp_ratio=model_config["mlp_ratio"],
|
249 |
+
ffn_type=model_config.get("ffn_type", "conv1d_conv1d"),
|
250 |
+
ffn_gated_glu=model_config.get("ffn_gated_glu", True),
|
251 |
+
ffn_act_layer=model_config.get("ffn_act_layer", "gelu"),
|
252 |
+
ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5),
|
253 |
+
|
254 |
+
use_rope=model_config.get("use_rope", False),
|
255 |
+
rope_params=model_config.get("rope_params", { "max_position_embeddings": 4096,"rope_base": 10000,"rope_interpolation_factor": 1 }),
|
256 |
+
|
257 |
+
position_embedding_type=model_config["position_embedding_type"],
|
258 |
+
max_seq_len=model_config["max_seq_len"],
|
259 |
+
output_size=model_config["input_size"],
|
260 |
+
prompt_cfg_dropout=0
|
261 |
+
)
|
262 |
+
cfg_semantic_token_id = model_config["semantic_vocab_size"]
|
263 |
+
|
264 |
+
# load state_dict
|
265 |
+
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"]
|
266 |
+
speech_model_params = {k.replace("speech_model.", ""): v for k, v in state_dict.items() if "speech_model" in k}
|
267 |
+
dit.load_state_dict(speech_model_params, strict=True)
|
268 |
+
logger.info(f">>> Loaded checkpoint from {ckpt_path}")
|
269 |
+
|
270 |
+
return cls(speech_model=dit, device=device, normalize_mel=config["normalize_mel"], mel_mean=config["mel_mean"], mel_std=config["mel_std"], max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
|
271 |
+
use_cfg=use_cfg, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_scale=cfg_scale, cfg_schedule=cfg_schedule, cfg_token_id=cfg_semantic_token_id)
|
272 |
+
|
273 |
+
|
modules/audio_detokenizer/vocoder/activations.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, sin, pow
|
3 |
+
from torch.nn import Parameter
|
4 |
+
|
5 |
+
|
6 |
+
class Snake(nn.Module):
|
7 |
+
"""
|
8 |
+
Implementation of a sine-based periodic activation function
|
9 |
+
Shape:
|
10 |
+
- Input: (B, C, T)
|
11 |
+
- Output: (B, C, T), same shape as the input
|
12 |
+
Parameters:
|
13 |
+
- alpha - trainable parameter
|
14 |
+
References:
|
15 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
16 |
+
https://arxiv.org/abs/2006.08195
|
17 |
+
Examples:
|
18 |
+
>>> a1 = snake(256)
|
19 |
+
>>> x = torch.randn(256)
|
20 |
+
>>> x = a1(x)
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
Initialization.
|
28 |
+
INPUT:
|
29 |
+
- in_features: shape of the input
|
30 |
+
- alpha: trainable parameter
|
31 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
32 |
+
alpha will be trained along with the rest of your model.
|
33 |
+
"""
|
34 |
+
super(Snake, self).__init__()
|
35 |
+
self.in_features = in_features
|
36 |
+
|
37 |
+
# Initialize alpha
|
38 |
+
self.alpha_logscale = alpha_logscale
|
39 |
+
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
40 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
41 |
+
else: # Linear scale alphas initialized to ones
|
42 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
43 |
+
|
44 |
+
self.alpha.requires_grad = alpha_trainable
|
45 |
+
|
46 |
+
self.no_div_by_zero = 0.000000001
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""
|
50 |
+
Forward pass of the function.
|
51 |
+
Applies the function to the input elementwise.
|
52 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
53 |
+
"""
|
54 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
55 |
+
if self.alpha_logscale:
|
56 |
+
alpha = torch.exp(alpha)
|
57 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
58 |
+
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
class SnakeBeta(nn.Module):
|
63 |
+
"""
|
64 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
65 |
+
Shape:
|
66 |
+
- Input: (B, C, T)
|
67 |
+
- Output: (B, C, T), same shape as the input
|
68 |
+
Parameters:
|
69 |
+
- alpha - trainable parameter that controls frequency
|
70 |
+
- beta - trainable parameter that controls magnitude
|
71 |
+
References:
|
72 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
73 |
+
https://arxiv.org/abs/2006.08195
|
74 |
+
Examples:
|
75 |
+
>>> a1 = snakebeta(256)
|
76 |
+
>>> x = torch.randn(256)
|
77 |
+
>>> x = a1(x)
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
82 |
+
):
|
83 |
+
"""
|
84 |
+
Initialization.
|
85 |
+
INPUT:
|
86 |
+
- in_features: shape of the input
|
87 |
+
- alpha - trainable parameter that controls frequency
|
88 |
+
- beta - trainable parameter that controls magnitude
|
89 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
90 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
91 |
+
alpha will be trained along with the rest of your model.
|
92 |
+
"""
|
93 |
+
super(SnakeBeta, self).__init__()
|
94 |
+
self.in_features = in_features
|
95 |
+
|
96 |
+
# Initialize alpha
|
97 |
+
self.alpha_logscale = alpha_logscale
|
98 |
+
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
99 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
100 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
101 |
+
else: # Linear scale alphas initialized to ones
|
102 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
103 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
104 |
+
|
105 |
+
self.alpha.requires_grad = alpha_trainable
|
106 |
+
self.beta.requires_grad = alpha_trainable
|
107 |
+
|
108 |
+
self.no_div_by_zero = 0.000000001
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
"""
|
112 |
+
Forward pass of the function.
|
113 |
+
Applies the function to the input elementwise.
|
114 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
115 |
+
"""
|
116 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
117 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
118 |
+
if self.alpha_logscale:
|
119 |
+
alpha = torch.exp(alpha)
|
120 |
+
beta = torch.exp(beta)
|
121 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
122 |
+
|
123 |
+
return x
|
modules/audio_detokenizer/vocoder/alias_free_activation/__init__.py
ADDED
File without changes
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/__init__.py
ADDED
File without changes
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/activation1d.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from ..torch.resample import UpSample1d, DownSample1d
|
7 |
+
|
8 |
+
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
9 |
+
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda import load
|
10 |
+
|
11 |
+
anti_alias_activation_cuda = load.load()
|
12 |
+
|
13 |
+
|
14 |
+
class FusedAntiAliasActivation(torch.autograd.Function):
|
15 |
+
"""
|
16 |
+
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
17 |
+
The hyperparameters are hard-coded in the kernel to maximize speed.
|
18 |
+
NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
|
19 |
+
"""
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
23 |
+
activation_results = anti_alias_activation_cuda.forward(
|
24 |
+
inputs, up_ftr, down_ftr, alpha, beta
|
25 |
+
)
|
26 |
+
|
27 |
+
return activation_results
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def backward(ctx, output_grads):
|
31 |
+
raise NotImplementedError
|
32 |
+
return output_grads, None, None
|
33 |
+
|
34 |
+
|
35 |
+
class Activation1d(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
activation,
|
39 |
+
up_ratio: int = 2,
|
40 |
+
down_ratio: int = 2,
|
41 |
+
up_kernel_size: int = 12,
|
42 |
+
down_kernel_size: int = 12,
|
43 |
+
fused: bool = True,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.up_ratio = up_ratio
|
47 |
+
self.down_ratio = down_ratio
|
48 |
+
self.act = activation
|
49 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
50 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
51 |
+
|
52 |
+
self.fused = fused # Whether to use fused CUDA kernel or not
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
if not self.fused:
|
56 |
+
x = self.upsample(x)
|
57 |
+
x = self.act(x)
|
58 |
+
x = self.downsample(x)
|
59 |
+
return x
|
60 |
+
else:
|
61 |
+
if self.act.__class__.__name__ == "Snake":
|
62 |
+
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
63 |
+
else:
|
64 |
+
beta = (
|
65 |
+
self.act.beta.data
|
66 |
+
) # Snakebeta uses different params for alpha and beta
|
67 |
+
alpha = self.act.alpha.data
|
68 |
+
if (
|
69 |
+
not self.act.alpha_logscale
|
70 |
+
): # Exp baked into cuda kernel, cancel it out with a log
|
71 |
+
alpha = torch.log(alpha)
|
72 |
+
beta = torch.log(beta)
|
73 |
+
|
74 |
+
x = FusedAntiAliasActivation.apply(
|
75 |
+
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
76 |
+
)
|
77 |
+
return x
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <torch/extension.h>
|
18 |
+
|
19 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
20 |
+
|
21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
23 |
+
}
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include <cuda.h>
|
19 |
+
#include <cuda_runtime.h>
|
20 |
+
#include <cuda_fp16.h>
|
21 |
+
#include <cuda_profiler_api.h>
|
22 |
+
#include <ATen/cuda/CUDAContext.h>
|
23 |
+
#include <torch/extension.h>
|
24 |
+
#include "type_shim.h"
|
25 |
+
#include <assert.h>
|
26 |
+
#include <cfloat>
|
27 |
+
#include <limits>
|
28 |
+
#include <stdint.h>
|
29 |
+
#include <c10/macros/Macros.h>
|
30 |
+
|
31 |
+
namespace
|
32 |
+
{
|
33 |
+
// Hard-coded hyperparameters
|
34 |
+
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
35 |
+
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
36 |
+
constexpr int BUFFER_SIZE = 32;
|
37 |
+
constexpr int FILTER_SIZE = 12;
|
38 |
+
constexpr int HALF_FILTER_SIZE = 6;
|
39 |
+
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
40 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
41 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
42 |
+
|
43 |
+
template <typename input_t, typename output_t, typename acc_t>
|
44 |
+
__global__ void anti_alias_activation_forward(
|
45 |
+
output_t *dst,
|
46 |
+
const input_t *src,
|
47 |
+
const input_t *up_ftr,
|
48 |
+
const input_t *down_ftr,
|
49 |
+
const input_t *alpha,
|
50 |
+
const input_t *beta,
|
51 |
+
int batch_size,
|
52 |
+
int channels,
|
53 |
+
int seq_len)
|
54 |
+
{
|
55 |
+
// Up and downsample filters
|
56 |
+
input_t up_filter[FILTER_SIZE];
|
57 |
+
input_t down_filter[FILTER_SIZE];
|
58 |
+
|
59 |
+
// Load data from global memory including extra indices reserved for replication paddings
|
60 |
+
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
61 |
+
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
62 |
+
|
63 |
+
// Output stores downsampled output before writing to dst
|
64 |
+
output_t output[BUFFER_SIZE];
|
65 |
+
|
66 |
+
// blockDim/threadIdx = (128, 1, 1)
|
67 |
+
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
68 |
+
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
69 |
+
int local_offset = threadIdx.x * BUFFER_SIZE;
|
70 |
+
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
71 |
+
|
72 |
+
// intermediate have double the seq_len
|
73 |
+
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
74 |
+
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
75 |
+
|
76 |
+
// Get values needed for replication padding before moving pointer
|
77 |
+
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
78 |
+
input_t seq_left_most_value = right_most_pntr[0];
|
79 |
+
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
80 |
+
|
81 |
+
// Move src and dst pointers
|
82 |
+
src += block_offset + local_offset;
|
83 |
+
dst += block_offset + local_offset;
|
84 |
+
|
85 |
+
// Alpha and beta values for snake activatons. Applies exp by default
|
86 |
+
alpha = alpha + blockIdx.y;
|
87 |
+
input_t alpha_val = expf(alpha[0]);
|
88 |
+
beta = beta + blockIdx.y;
|
89 |
+
input_t beta_val = expf(beta[0]);
|
90 |
+
|
91 |
+
#pragma unroll
|
92 |
+
for (int it = 0; it < FILTER_SIZE; it += 1)
|
93 |
+
{
|
94 |
+
up_filter[it] = up_ftr[it];
|
95 |
+
down_filter[it] = down_ftr[it];
|
96 |
+
}
|
97 |
+
|
98 |
+
// Apply replication padding for upsampling, matching torch impl
|
99 |
+
#pragma unroll
|
100 |
+
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
101 |
+
{
|
102 |
+
int element_index = seq_offset + it; // index for element
|
103 |
+
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
104 |
+
{
|
105 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
106 |
+
}
|
107 |
+
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
108 |
+
{
|
109 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
110 |
+
}
|
111 |
+
if ((element_index >= 0) && (element_index < seq_len))
|
112 |
+
{
|
113 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
114 |
+
}
|
115 |
+
}
|
116 |
+
|
117 |
+
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
118 |
+
#pragma unroll
|
119 |
+
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
120 |
+
{
|
121 |
+
input_t acc = 0.0;
|
122 |
+
int element_index = intermediate_seq_offset + it; // index for intermediate
|
123 |
+
#pragma unroll
|
124 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
125 |
+
{
|
126 |
+
if ((element_index + f_idx) >= 0)
|
127 |
+
{
|
128 |
+
acc += up_filter[f_idx] * elements[it + f_idx];
|
129 |
+
}
|
130 |
+
}
|
131 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
132 |
+
}
|
133 |
+
|
134 |
+
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
135 |
+
double no_div_by_zero = 0.000000001;
|
136 |
+
#pragma unroll
|
137 |
+
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
138 |
+
{
|
139 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
140 |
+
}
|
141 |
+
|
142 |
+
// Apply replication padding before downsampling conv from intermediates
|
143 |
+
#pragma unroll
|
144 |
+
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
145 |
+
{
|
146 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
147 |
+
}
|
148 |
+
#pragma unroll
|
149 |
+
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
150 |
+
{
|
151 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
152 |
+
}
|
153 |
+
|
154 |
+
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
155 |
+
#pragma unroll
|
156 |
+
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
157 |
+
{
|
158 |
+
input_t acc = 0.0;
|
159 |
+
#pragma unroll
|
160 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
161 |
+
{
|
162 |
+
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
163 |
+
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
164 |
+
}
|
165 |
+
output[it] = acc;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Write output to dst
|
169 |
+
#pragma unroll
|
170 |
+
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
171 |
+
{
|
172 |
+
int element_index = seq_offset + it;
|
173 |
+
if (element_index < seq_len)
|
174 |
+
{
|
175 |
+
dst[it] = output[it];
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
}
|
180 |
+
|
181 |
+
template <typename input_t, typename output_t, typename acc_t>
|
182 |
+
void dispatch_anti_alias_activation_forward(
|
183 |
+
output_t *dst,
|
184 |
+
const input_t *src,
|
185 |
+
const input_t *up_ftr,
|
186 |
+
const input_t *down_ftr,
|
187 |
+
const input_t *alpha,
|
188 |
+
const input_t *beta,
|
189 |
+
int batch_size,
|
190 |
+
int channels,
|
191 |
+
int seq_len)
|
192 |
+
{
|
193 |
+
if (seq_len == 0)
|
194 |
+
{
|
195 |
+
return;
|
196 |
+
}
|
197 |
+
else
|
198 |
+
{
|
199 |
+
// Use 128 threads per block to maximimize gpu utilization
|
200 |
+
constexpr int threads_per_block = 128;
|
201 |
+
constexpr int seq_len_per_block = 4096;
|
202 |
+
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
203 |
+
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
204 |
+
dim3 threads(threads_per_block, 1, 1);
|
205 |
+
|
206 |
+
anti_alias_activation_forward<input_t, output_t, acc_t>
|
207 |
+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
213 |
+
{
|
214 |
+
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
215 |
+
const int batches = input.size(0);
|
216 |
+
const int channels = input.size(1);
|
217 |
+
const int seq_len = input.size(2);
|
218 |
+
|
219 |
+
// Output
|
220 |
+
auto act_options = input.options().requires_grad(false);
|
221 |
+
|
222 |
+
torch::Tensor anti_alias_activation_results =
|
223 |
+
torch::empty({batches, channels, seq_len}, act_options);
|
224 |
+
|
225 |
+
void *input_ptr = static_cast<void *>(input.data_ptr());
|
226 |
+
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
227 |
+
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
228 |
+
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
229 |
+
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
230 |
+
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
231 |
+
|
232 |
+
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
233 |
+
input.scalar_type(),
|
234 |
+
"dispatch anti alias activation_forward",
|
235 |
+
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
236 |
+
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
237 |
+
reinterpret_cast<const scalar_t *>(input_ptr),
|
238 |
+
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
239 |
+
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
240 |
+
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
241 |
+
reinterpret_cast<const scalar_t *>(beta_ptr),
|
242 |
+
batches,
|
243 |
+
channels,
|
244 |
+
seq_len););
|
245 |
+
return anti_alias_activation_results;
|
246 |
+
}
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/compat.h
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
/*This code is copied fron NVIDIA apex:
|
18 |
+
* https://github.com/NVIDIA/apex
|
19 |
+
* with minor changes. */
|
20 |
+
|
21 |
+
#ifndef TORCH_CHECK
|
22 |
+
#define TORCH_CHECK AT_CHECK
|
23 |
+
#endif
|
24 |
+
|
25 |
+
#ifdef VERSION_GE_1_3
|
26 |
+
#define DATA_PTR data_ptr
|
27 |
+
#else
|
28 |
+
#define DATA_PTR data
|
29 |
+
#endif
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/load.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import pathlib
|
6 |
+
import subprocess
|
7 |
+
|
8 |
+
from torch.utils import cpp_extension
|
9 |
+
|
10 |
+
"""
|
11 |
+
Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
|
12 |
+
Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
|
13 |
+
"""
|
14 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
15 |
+
|
16 |
+
|
17 |
+
def load():
|
18 |
+
# Check if cuda 11 is installed for compute capability 8.0
|
19 |
+
cc_flag = []
|
20 |
+
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
21 |
+
if int(bare_metal_major) >= 11:
|
22 |
+
cc_flag.append("-gencode")
|
23 |
+
cc_flag.append("arch=compute_80,code=sm_80")
|
24 |
+
|
25 |
+
# Build path
|
26 |
+
srcpath = pathlib.Path(__file__).parent.absolute()
|
27 |
+
buildpath = srcpath / "build"
|
28 |
+
_create_build_dir(buildpath)
|
29 |
+
|
30 |
+
# Helper function to build the kernels.
|
31 |
+
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
32 |
+
return cpp_extension.load(
|
33 |
+
name=name,
|
34 |
+
sources=sources,
|
35 |
+
build_directory=buildpath,
|
36 |
+
extra_cflags=[
|
37 |
+
"-O3",
|
38 |
+
],
|
39 |
+
extra_cuda_cflags=[
|
40 |
+
"-O3",
|
41 |
+
"-gencode",
|
42 |
+
"arch=compute_70,code=sm_70",
|
43 |
+
"--use_fast_math",
|
44 |
+
]
|
45 |
+
+ extra_cuda_flags
|
46 |
+
+ cc_flag,
|
47 |
+
verbose=True,
|
48 |
+
)
|
49 |
+
|
50 |
+
extra_cuda_flags = [
|
51 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
52 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
53 |
+
"--expt-relaxed-constexpr",
|
54 |
+
"--expt-extended-lambda",
|
55 |
+
]
|
56 |
+
|
57 |
+
sources = [
|
58 |
+
srcpath / "anti_alias_activation.cpp",
|
59 |
+
srcpath / "anti_alias_activation_cuda.cu",
|
60 |
+
]
|
61 |
+
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
62 |
+
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
63 |
+
)
|
64 |
+
|
65 |
+
return anti_alias_activation_cuda
|
66 |
+
|
67 |
+
|
68 |
+
def _get_cuda_bare_metal_version(cuda_dir):
|
69 |
+
raw_output = subprocess.check_output(
|
70 |
+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
71 |
+
)
|
72 |
+
output = raw_output.split()
|
73 |
+
release_idx = output.index("release") + 1
|
74 |
+
release = output[release_idx].split(".")
|
75 |
+
bare_metal_major = release[0]
|
76 |
+
bare_metal_minor = release[1][0]
|
77 |
+
|
78 |
+
return raw_output, bare_metal_major, bare_metal_minor
|
79 |
+
|
80 |
+
|
81 |
+
def _create_build_dir(buildpath):
|
82 |
+
try:
|
83 |
+
os.mkdir(buildpath)
|
84 |
+
except OSError:
|
85 |
+
if not os.path.isdir(buildpath):
|
86 |
+
print(f"Creation of the build directory {buildpath} failed")
|
modules/audio_detokenizer/vocoder/alias_free_activation/cuda/type_shim.h
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include "compat.h"
|
19 |
+
|
20 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
21 |
+
switch (TYPE) \
|
22 |
+
{ \
|
23 |
+
case at::ScalarType::Float: \
|
24 |
+
{ \
|
25 |
+
using scalar_t = float; \
|
26 |
+
__VA_ARGS__; \
|
27 |
+
break; \
|
28 |
+
} \
|
29 |
+
case at::ScalarType::Half: \
|
30 |
+
{ \
|
31 |
+
using scalar_t = at::Half; \
|
32 |
+
__VA_ARGS__; \
|
33 |
+
break; \
|
34 |
+
} \
|
35 |
+
case at::ScalarType::BFloat16: \
|
36 |
+
{ \
|
37 |
+
using scalar_t = at::BFloat16; \
|
38 |
+
__VA_ARGS__; \
|
39 |
+
break; \
|
40 |
+
} \
|
41 |
+
default: \
|
42 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
43 |
+
}
|
44 |
+
|
45 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
46 |
+
switch (TYPEIN) \
|
47 |
+
{ \
|
48 |
+
case at::ScalarType::Float: \
|
49 |
+
{ \
|
50 |
+
using scalar_t_in = float; \
|
51 |
+
switch (TYPEOUT) \
|
52 |
+
{ \
|
53 |
+
case at::ScalarType::Float: \
|
54 |
+
{ \
|
55 |
+
using scalar_t_out = float; \
|
56 |
+
__VA_ARGS__; \
|
57 |
+
break; \
|
58 |
+
} \
|
59 |
+
case at::ScalarType::Half: \
|
60 |
+
{ \
|
61 |
+
using scalar_t_out = at::Half; \
|
62 |
+
__VA_ARGS__; \
|
63 |
+
break; \
|
64 |
+
} \
|
65 |
+
case at::ScalarType::BFloat16: \
|
66 |
+
{ \
|
67 |
+
using scalar_t_out = at::BFloat16; \
|
68 |
+
__VA_ARGS__; \
|
69 |
+
break; \
|
70 |
+
} \
|
71 |
+
default: \
|
72 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
73 |
+
} \
|
74 |
+
break; \
|
75 |
+
} \
|
76 |
+
case at::ScalarType::Half: \
|
77 |
+
{ \
|
78 |
+
using scalar_t_in = at::Half; \
|
79 |
+
using scalar_t_out = at::Half; \
|
80 |
+
__VA_ARGS__; \
|
81 |
+
break; \
|
82 |
+
} \
|
83 |
+
case at::ScalarType::BFloat16: \
|
84 |
+
{ \
|
85 |
+
using scalar_t_in = at::BFloat16; \
|
86 |
+
using scalar_t_out = at::BFloat16; \
|
87 |
+
__VA_ARGS__; \
|
88 |
+
break; \
|
89 |
+
} \
|
90 |
+
default: \
|
91 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
92 |
+
}
|
modules/audio_detokenizer/vocoder/alias_free_activation/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
from .filter import *
|
5 |
+
from .resample import *
|
6 |
+
from .act import *
|
modules/audio_detokenizer/vocoder/alias_free_activation/torch/act.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from .resample import UpSample1d, DownSample1d
|
6 |
+
|
7 |
+
|
8 |
+
class Activation1d(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
activation,
|
12 |
+
up_ratio: int = 2,
|
13 |
+
down_ratio: int = 2,
|
14 |
+
up_kernel_size: int = 12,
|
15 |
+
down_kernel_size: int = 12,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.up_ratio = up_ratio
|
19 |
+
self.down_ratio = down_ratio
|
20 |
+
self.act = activation
|
21 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
22 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
23 |
+
|
24 |
+
# x: [B,C,T]
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.upsample(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.downsample(x)
|
29 |
+
|
30 |
+
return x
|
modules/audio_detokenizer/vocoder/alias_free_activation/torch/filter.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
|
9 |
+
if "sinc" in dir(torch):
|
10 |
+
sinc = torch.sinc
|
11 |
+
else:
|
12 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
13 |
+
# https://adefossez.github.io/julius/julius/core.html
|
14 |
+
# LICENSE is in incl_licenses directory.
|
15 |
+
def sinc(x: torch.Tensor):
|
16 |
+
"""
|
17 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
18 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
19 |
+
"""
|
20 |
+
return torch.where(
|
21 |
+
x == 0,
|
22 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
23 |
+
torch.sin(math.pi * x) / math.pi / x,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
28 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
29 |
+
# LICENSE is in incl_licenses directory.
|
30 |
+
def kaiser_sinc_filter1d(
|
31 |
+
cutoff, half_width, kernel_size
|
32 |
+
): # return filter [1,1,kernel_size]
|
33 |
+
even = kernel_size % 2 == 0
|
34 |
+
half_size = kernel_size // 2
|
35 |
+
|
36 |
+
# For kaiser window
|
37 |
+
delta_f = 4 * half_width
|
38 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
39 |
+
if A > 50.0:
|
40 |
+
beta = 0.1102 * (A - 8.7)
|
41 |
+
elif A >= 21.0:
|
42 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
43 |
+
else:
|
44 |
+
beta = 0.0
|
45 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
46 |
+
|
47 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
48 |
+
if even:
|
49 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
50 |
+
else:
|
51 |
+
time = torch.arange(kernel_size) - half_size
|
52 |
+
if cutoff == 0:
|
53 |
+
filter_ = torch.zeros_like(time)
|
54 |
+
else:
|
55 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
56 |
+
"""
|
57 |
+
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
58 |
+
"""
|
59 |
+
filter_ /= filter_.sum()
|
60 |
+
filter = filter_.view(1, 1, kernel_size)
|
61 |
+
|
62 |
+
return filter
|
63 |
+
|
64 |
+
|
65 |
+
class LowPassFilter1d(nn.Module):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
cutoff=0.5,
|
69 |
+
half_width=0.6,
|
70 |
+
stride: int = 1,
|
71 |
+
padding: bool = True,
|
72 |
+
padding_mode: str = "replicate",
|
73 |
+
kernel_size: int = 12,
|
74 |
+
):
|
75 |
+
"""
|
76 |
+
kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
|
77 |
+
"""
|
78 |
+
super().__init__()
|
79 |
+
if cutoff < -0.0:
|
80 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
81 |
+
if cutoff > 0.5:
|
82 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
83 |
+
self.kernel_size = kernel_size
|
84 |
+
self.even = kernel_size % 2 == 0
|
85 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
86 |
+
self.pad_right = kernel_size // 2
|
87 |
+
self.stride = stride
|
88 |
+
self.padding = padding
|
89 |
+
self.padding_mode = padding_mode
|
90 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
91 |
+
self.register_buffer("filter", filter)
|
92 |
+
|
93 |
+
# Input [B, C, T]
|
94 |
+
def forward(self, x):
|
95 |
+
_, C, _ = x.shape
|
96 |
+
|
97 |
+
if self.padding:
|
98 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
99 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
100 |
+
|
101 |
+
return out
|
modules/audio_detokenizer/vocoder/alias_free_activation/torch/resample.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from .filter import LowPassFilter1d
|
7 |
+
from .filter import kaiser_sinc_filter1d
|
8 |
+
|
9 |
+
|
10 |
+
class UpSample1d(nn.Module):
|
11 |
+
def __init__(self, ratio=2, kernel_size=None):
|
12 |
+
super().__init__()
|
13 |
+
self.ratio = ratio
|
14 |
+
self.kernel_size = (
|
15 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
16 |
+
)
|
17 |
+
self.stride = ratio
|
18 |
+
self.pad = self.kernel_size // ratio - 1
|
19 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
20 |
+
self.pad_right = (
|
21 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
22 |
+
)
|
23 |
+
filter = kaiser_sinc_filter1d(
|
24 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
25 |
+
)
|
26 |
+
self.register_buffer("filter", filter)
|
27 |
+
|
28 |
+
# x: [B, C, T]
|
29 |
+
def forward(self, x):
|
30 |
+
_, C, _ = x.shape
|
31 |
+
|
32 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
33 |
+
x = self.ratio * F.conv_transpose1d(
|
34 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
35 |
+
)
|
36 |
+
x = x[..., self.pad_left : -self.pad_right]
|
37 |
+
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
class DownSample1d(nn.Module):
|
42 |
+
def __init__(self, ratio=2, kernel_size=None):
|
43 |
+
super().__init__()
|
44 |
+
self.ratio = ratio
|
45 |
+
self.kernel_size = (
|
46 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
47 |
+
)
|
48 |
+
self.lowpass = LowPassFilter1d(
|
49 |
+
cutoff=0.5 / ratio,
|
50 |
+
half_width=0.6 / ratio,
|
51 |
+
stride=ratio,
|
52 |
+
kernel_size=self.kernel_size,
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
xx = self.lowpass(x)
|
57 |
+
|
58 |
+
return xx
|
modules/audio_detokenizer/vocoder/bigvgan.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional, Union, Dict
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
15 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
16 |
+
|
17 |
+
from modules.audio_detokenizer.vocoder.activations import Snake, SnakeBeta
|
18 |
+
from modules.audio_detokenizer.vocoder.utils import init_weights, get_padding
|
19 |
+
from modules.audio_detokenizer.vocoder.alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
20 |
+
from modules.audio_detokenizer.vocoder.utils import AttrDict
|
21 |
+
|
22 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
23 |
+
|
24 |
+
|
25 |
+
def load_hparams_from_json(path) -> AttrDict:
|
26 |
+
with open(path) as f:
|
27 |
+
data = f.read()
|
28 |
+
return AttrDict(json.loads(data))
|
29 |
+
|
30 |
+
|
31 |
+
class AMPBlock1(torch.nn.Module):
|
32 |
+
"""
|
33 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
34 |
+
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
35 |
+
|
36 |
+
Args:
|
37 |
+
h (AttrDict): Hyperparameters.
|
38 |
+
channels (int): Number of convolution channels.
|
39 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
40 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
41 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
h: AttrDict,
|
47 |
+
channels: int,
|
48 |
+
kernel_size: int = 3,
|
49 |
+
dilation: tuple = (1, 3, 5),
|
50 |
+
activation: str = None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.h = h
|
55 |
+
|
56 |
+
self.convs1 = nn.ModuleList(
|
57 |
+
[
|
58 |
+
weight_norm(
|
59 |
+
Conv1d(
|
60 |
+
channels,
|
61 |
+
channels,
|
62 |
+
kernel_size,
|
63 |
+
stride=1,
|
64 |
+
dilation=d,
|
65 |
+
padding=get_padding(kernel_size, d),
|
66 |
+
)
|
67 |
+
)
|
68 |
+
for d in dilation
|
69 |
+
]
|
70 |
+
)
|
71 |
+
self.convs1.apply(init_weights)
|
72 |
+
|
73 |
+
self.convs2 = nn.ModuleList(
|
74 |
+
[
|
75 |
+
weight_norm(
|
76 |
+
Conv1d(
|
77 |
+
channels,
|
78 |
+
channels,
|
79 |
+
kernel_size,
|
80 |
+
stride=1,
|
81 |
+
dilation=1,
|
82 |
+
padding=get_padding(kernel_size, 1),
|
83 |
+
)
|
84 |
+
)
|
85 |
+
for _ in range(len(dilation))
|
86 |
+
]
|
87 |
+
)
|
88 |
+
self.convs2.apply(init_weights)
|
89 |
+
|
90 |
+
self.num_layers = len(self.convs1) + len(
|
91 |
+
self.convs2
|
92 |
+
) # Total number of conv layers
|
93 |
+
|
94 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
95 |
+
if self.h.get("use_cuda_kernel", False):
|
96 |
+
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
|
97 |
+
Activation1d as CudaActivation1d,
|
98 |
+
)
|
99 |
+
|
100 |
+
Activation1d = CudaActivation1d
|
101 |
+
else:
|
102 |
+
Activation1d = TorchActivation1d
|
103 |
+
|
104 |
+
# Activation functions
|
105 |
+
if activation == "snake":
|
106 |
+
self.activations = nn.ModuleList(
|
107 |
+
[
|
108 |
+
Activation1d(
|
109 |
+
activation=Snake(
|
110 |
+
channels, alpha_logscale=h.snake_logscale
|
111 |
+
)
|
112 |
+
)
|
113 |
+
for _ in range(self.num_layers)
|
114 |
+
]
|
115 |
+
)
|
116 |
+
elif activation == "snakebeta":
|
117 |
+
self.activations = nn.ModuleList(
|
118 |
+
[
|
119 |
+
Activation1d(
|
120 |
+
activation=SnakeBeta(
|
121 |
+
channels, alpha_logscale=h.snake_logscale
|
122 |
+
)
|
123 |
+
)
|
124 |
+
for _ in range(self.num_layers)
|
125 |
+
]
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
raise NotImplementedError(
|
129 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
134 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
135 |
+
xt = a1(x)
|
136 |
+
xt = c1(xt)
|
137 |
+
xt = a2(xt)
|
138 |
+
xt = c2(xt)
|
139 |
+
x = xt + x
|
140 |
+
|
141 |
+
return x
|
142 |
+
|
143 |
+
def remove_weight_norm(self):
|
144 |
+
for l in self.convs1:
|
145 |
+
remove_weight_norm(l)
|
146 |
+
for l in self.convs2:
|
147 |
+
remove_weight_norm(l)
|
148 |
+
|
149 |
+
|
150 |
+
class AMPBlock2(torch.nn.Module):
|
151 |
+
"""
|
152 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
153 |
+
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
154 |
+
|
155 |
+
Args:
|
156 |
+
h (AttrDict): Hyperparameters.
|
157 |
+
channels (int): Number of convolution channels.
|
158 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
159 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
160 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
h: AttrDict,
|
166 |
+
channels: int,
|
167 |
+
kernel_size: int = 3,
|
168 |
+
dilation: tuple = (1, 3, 5),
|
169 |
+
activation: str = None,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
self.h = h
|
174 |
+
|
175 |
+
self.convs = nn.ModuleList(
|
176 |
+
[
|
177 |
+
weight_norm(
|
178 |
+
Conv1d(
|
179 |
+
channels,
|
180 |
+
channels,
|
181 |
+
kernel_size,
|
182 |
+
stride=1,
|
183 |
+
dilation=d,
|
184 |
+
padding=get_padding(kernel_size, d),
|
185 |
+
)
|
186 |
+
)
|
187 |
+
for d in dilation
|
188 |
+
]
|
189 |
+
)
|
190 |
+
self.convs.apply(init_weights)
|
191 |
+
|
192 |
+
self.num_layers = len(self.convs) # Total number of conv layers
|
193 |
+
|
194 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
195 |
+
if self.h.get("use_cuda_kernel", False):
|
196 |
+
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
|
197 |
+
Activation1d as CudaActivation1d,
|
198 |
+
)
|
199 |
+
|
200 |
+
Activation1d = CudaActivation1d
|
201 |
+
else:
|
202 |
+
Activation1d = TorchActivation1d
|
203 |
+
|
204 |
+
# Activation functions
|
205 |
+
if activation == "snake":
|
206 |
+
self.activations = nn.ModuleList(
|
207 |
+
[
|
208 |
+
Activation1d(
|
209 |
+
activation=Snake(
|
210 |
+
channels, alpha_logscale=h.snake_logscale
|
211 |
+
)
|
212 |
+
)
|
213 |
+
for _ in range(self.num_layers)
|
214 |
+
]
|
215 |
+
)
|
216 |
+
elif activation == "snakebeta":
|
217 |
+
self.activations = nn.ModuleList(
|
218 |
+
[
|
219 |
+
Activation1d(
|
220 |
+
activation=SnakeBeta(
|
221 |
+
channels, alpha_logscale=h.snake_logscale
|
222 |
+
)
|
223 |
+
)
|
224 |
+
for _ in range(self.num_layers)
|
225 |
+
]
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
raise NotImplementedError(
|
229 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
230 |
+
)
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
for c, a in zip(self.convs, self.activations):
|
234 |
+
xt = a(x)
|
235 |
+
xt = c(xt)
|
236 |
+
x = xt + x
|
237 |
+
|
238 |
+
def remove_weight_norm(self):
|
239 |
+
for l in self.convs:
|
240 |
+
remove_weight_norm(l)
|
241 |
+
|
242 |
+
|
243 |
+
class BigVGAN(
|
244 |
+
torch.nn.Module,
|
245 |
+
PyTorchModelHubMixin,
|
246 |
+
library_name="bigvgan",
|
247 |
+
repo_url="https://github.com/NVIDIA/BigVGAN",
|
248 |
+
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
249 |
+
pipeline_tag="audio-to-audio",
|
250 |
+
license="mit",
|
251 |
+
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
252 |
+
):
|
253 |
+
"""
|
254 |
+
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
255 |
+
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
h (AttrDict): Hyperparameters.
|
259 |
+
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
260 |
+
|
261 |
+
Note:
|
262 |
+
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
263 |
+
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
267 |
+
super().__init__()
|
268 |
+
self.h = h
|
269 |
+
self.h["use_cuda_kernel"] = use_cuda_kernel
|
270 |
+
|
271 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
272 |
+
if self.h.get("use_cuda_kernel", False):
|
273 |
+
from modules.audio_detokenizer.vocoder.alias_free_activation.cuda.activation1d import (
|
274 |
+
Activation1d as CudaActivation1d,
|
275 |
+
)
|
276 |
+
|
277 |
+
Activation1d = CudaActivation1d
|
278 |
+
else:
|
279 |
+
Activation1d = TorchActivation1d
|
280 |
+
|
281 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
282 |
+
self.num_upsamples = len(h.upsample_rates)
|
283 |
+
|
284 |
+
# Pre-conv
|
285 |
+
self.conv_pre = weight_norm(
|
286 |
+
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
287 |
+
)
|
288 |
+
|
289 |
+
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
290 |
+
if h.resblock == "1":
|
291 |
+
resblock_class = AMPBlock1
|
292 |
+
elif h.resblock == "2":
|
293 |
+
resblock_class = AMPBlock2
|
294 |
+
else:
|
295 |
+
raise ValueError(
|
296 |
+
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
297 |
+
)
|
298 |
+
|
299 |
+
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
300 |
+
self.ups = nn.ModuleList()
|
301 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
302 |
+
self.ups.append(
|
303 |
+
nn.ModuleList(
|
304 |
+
[
|
305 |
+
weight_norm(
|
306 |
+
ConvTranspose1d(
|
307 |
+
h.upsample_initial_channel // (2**i),
|
308 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
309 |
+
k,
|
310 |
+
u,
|
311 |
+
padding=(k - u) // 2,
|
312 |
+
)
|
313 |
+
)
|
314 |
+
]
|
315 |
+
)
|
316 |
+
)
|
317 |
+
|
318 |
+
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
319 |
+
self.resblocks = nn.ModuleList()
|
320 |
+
for i in range(len(self.ups)):
|
321 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
322 |
+
for j, (k, d) in enumerate(
|
323 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
324 |
+
):
|
325 |
+
self.resblocks.append(
|
326 |
+
resblock_class(h, ch, k, d, activation=h.activation)
|
327 |
+
)
|
328 |
+
|
329 |
+
# Post-conv
|
330 |
+
activation_post = (
|
331 |
+
Snake(ch, alpha_logscale=h.snake_logscale)
|
332 |
+
if h.activation == "snake"
|
333 |
+
else (
|
334 |
+
SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
335 |
+
if h.activation == "snakebeta"
|
336 |
+
else None
|
337 |
+
)
|
338 |
+
)
|
339 |
+
if activation_post is None:
|
340 |
+
raise NotImplementedError(
|
341 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
342 |
+
)
|
343 |
+
|
344 |
+
self.activation_post = Activation1d(activation=activation_post)
|
345 |
+
|
346 |
+
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
347 |
+
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
348 |
+
self.conv_post = weight_norm(
|
349 |
+
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
350 |
+
)
|
351 |
+
|
352 |
+
# Weight initialization
|
353 |
+
for i in range(len(self.ups)):
|
354 |
+
self.ups[i].apply(init_weights)
|
355 |
+
self.conv_post.apply(init_weights)
|
356 |
+
|
357 |
+
# Final tanh activation. Defaults to True for backward compatibility
|
358 |
+
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
359 |
+
|
360 |
+
def forward(self, x):
|
361 |
+
# Pre-conv
|
362 |
+
x = self.conv_pre(x)
|
363 |
+
|
364 |
+
for i in range(self.num_upsamples):
|
365 |
+
# Upsampling
|
366 |
+
for i_up in range(len(self.ups[i])):
|
367 |
+
x = self.ups[i][i_up](x)
|
368 |
+
# AMP blocks
|
369 |
+
xs = None
|
370 |
+
for j in range(self.num_kernels):
|
371 |
+
if xs is None:
|
372 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
373 |
+
else:
|
374 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
375 |
+
x = xs / self.num_kernels
|
376 |
+
|
377 |
+
# Post-conv
|
378 |
+
x = self.activation_post(x)
|
379 |
+
x = self.conv_post(x)
|
380 |
+
# Final tanh activation
|
381 |
+
if self.use_tanh_at_final:
|
382 |
+
x = torch.tanh(x)
|
383 |
+
else:
|
384 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
385 |
+
|
386 |
+
return x
|
387 |
+
|
388 |
+
def remove_weight_norm(self):
|
389 |
+
try:
|
390 |
+
print("Removing weight norm...")
|
391 |
+
for l in self.ups:
|
392 |
+
for l_i in l:
|
393 |
+
remove_weight_norm(l_i)
|
394 |
+
for l in self.resblocks:
|
395 |
+
l.remove_weight_norm()
|
396 |
+
remove_weight_norm(self.conv_pre)
|
397 |
+
remove_weight_norm(self.conv_post)
|
398 |
+
except ValueError:
|
399 |
+
print("[INFO] Model already removed weight norm. Skipping!")
|
400 |
+
pass
|
401 |
+
|
402 |
+
# Additional methods for huggingface_hub support
|
403 |
+
def _save_pretrained(self, save_directory: Path) -> None:
|
404 |
+
"""Save weights and config.json from a Pytorch model to a local directory."""
|
405 |
+
|
406 |
+
model_path = save_directory / "bigvgan_generator.pt"
|
407 |
+
torch.save({"generator": self.state_dict()}, model_path)
|
408 |
+
|
409 |
+
config_path = save_directory / "config.json"
|
410 |
+
with open(config_path, "w") as config_file:
|
411 |
+
json.dump(self.h, config_file, indent=4)
|
412 |
+
|
413 |
+
@classmethod
|
414 |
+
def _from_pretrained(
|
415 |
+
cls,
|
416 |
+
*,
|
417 |
+
model_id: str,
|
418 |
+
revision: str,
|
419 |
+
cache_dir: str,
|
420 |
+
force_download: bool,
|
421 |
+
proxies: Optional[Dict],
|
422 |
+
resume_download: bool,
|
423 |
+
local_files_only: bool,
|
424 |
+
token: Union[str, bool, None],
|
425 |
+
map_location: str = "cpu", # Additional argument
|
426 |
+
strict: bool = False, # Additional argument
|
427 |
+
use_cuda_kernel: bool = False,
|
428 |
+
**model_kwargs,
|
429 |
+
):
|
430 |
+
"""Load Pytorch pretrained weights and return the loaded model."""
|
431 |
+
|
432 |
+
# Download and load hyperparameters (h) used by BigVGAN
|
433 |
+
if os.path.isdir(model_id):
|
434 |
+
print("Loading config.json from local directory")
|
435 |
+
config_file = os.path.join(model_id, "config.json")
|
436 |
+
else:
|
437 |
+
config_file = hf_hub_download(
|
438 |
+
repo_id=model_id,
|
439 |
+
filename="config.json",
|
440 |
+
revision=revision,
|
441 |
+
cache_dir=cache_dir,
|
442 |
+
force_download=force_download,
|
443 |
+
proxies=proxies,
|
444 |
+
resume_download=resume_download,
|
445 |
+
token=token,
|
446 |
+
local_files_only=local_files_only,
|
447 |
+
)
|
448 |
+
h = load_hparams_from_json(config_file)
|
449 |
+
|
450 |
+
# instantiate BigVGAN using h
|
451 |
+
if use_cuda_kernel:
|
452 |
+
print(
|
453 |
+
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
454 |
+
)
|
455 |
+
print(
|
456 |
+
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
457 |
+
)
|
458 |
+
print(
|
459 |
+
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
460 |
+
)
|
461 |
+
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
462 |
+
|
463 |
+
# Download and load pretrained generator weight
|
464 |
+
if os.path.isdir(model_id):
|
465 |
+
print("Loading weights from local directory")
|
466 |
+
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
467 |
+
else:
|
468 |
+
print(f"Loading weights from {model_id}")
|
469 |
+
model_file = hf_hub_download(
|
470 |
+
repo_id=model_id,
|
471 |
+
filename="bigvgan_generator.pt",
|
472 |
+
revision=revision,
|
473 |
+
cache_dir=cache_dir,
|
474 |
+
force_download=force_download,
|
475 |
+
proxies=proxies,
|
476 |
+
resume_download=resume_download,
|
477 |
+
token=token,
|
478 |
+
local_files_only=local_files_only,
|
479 |
+
)
|
480 |
+
|
481 |
+
checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)
|
482 |
+
|
483 |
+
try:
|
484 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
485 |
+
except RuntimeError:
|
486 |
+
print(
|
487 |
+
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
488 |
+
)
|
489 |
+
model.remove_weight_norm()
|
490 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
491 |
+
|
492 |
+
return model
|
modules/audio_detokenizer/vocoder/utils.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from librosa.filters import mel as librosa_mel_fn
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
mel_basis_cache = {}
|
5 |
+
hann_window_cache = {}
|
6 |
+
|
7 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
8 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
9 |
+
|
10 |
+
|
11 |
+
def spectral_normalize_torch(magnitudes):
|
12 |
+
return dynamic_range_compression_torch(magnitudes)
|
13 |
+
|
14 |
+
def get_melspec(
|
15 |
+
y: torch.Tensor,
|
16 |
+
n_fft: int,
|
17 |
+
num_mels: int,
|
18 |
+
sampling_rate: int,
|
19 |
+
hop_size: int,
|
20 |
+
win_size: int,
|
21 |
+
fmin: int,
|
22 |
+
fmax: int = None,
|
23 |
+
center: bool = False,
|
24 |
+
) -> torch.Tensor:
|
25 |
+
"""
|
26 |
+
Calculate the mel spectrogram of an input signal.
|
27 |
+
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
28 |
+
|
29 |
+
Args:
|
30 |
+
y (torch.Tensor): Input signal.
|
31 |
+
n_fft (int): FFT size.
|
32 |
+
num_mels (int): Number of mel bins.
|
33 |
+
sampling_rate (int): Sampling rate of the input signal.
|
34 |
+
hop_size (int): Hop size for STFT.
|
35 |
+
win_size (int): Window size for STFT.
|
36 |
+
fmin (int): Minimum frequency for mel filterbank.
|
37 |
+
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
38 |
+
center (bool): Whether to pad the input to center the frames. Default is False.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
torch.Tensor: Mel spectrogram.
|
42 |
+
"""
|
43 |
+
if torch.min(y) < -1.0:
|
44 |
+
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
45 |
+
if torch.max(y) > 1.0:
|
46 |
+
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
47 |
+
|
48 |
+
device = y.device
|
49 |
+
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
50 |
+
|
51 |
+
if key not in mel_basis_cache:
|
52 |
+
mel = librosa_mel_fn(
|
53 |
+
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
54 |
+
)
|
55 |
+
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
56 |
+
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
57 |
+
|
58 |
+
mel_basis = mel_basis_cache[key]
|
59 |
+
hann_window = hann_window_cache[key]
|
60 |
+
|
61 |
+
padding = (n_fft - hop_size) // 2
|
62 |
+
y = torch.nn.functional.pad(
|
63 |
+
y.unsqueeze(1), (padding, padding), mode="reflect"
|
64 |
+
).squeeze(1)
|
65 |
+
|
66 |
+
spec = torch.stft(
|
67 |
+
y,
|
68 |
+
n_fft,
|
69 |
+
hop_length=hop_size,
|
70 |
+
win_length=win_size,
|
71 |
+
window=hann_window,
|
72 |
+
center=center,
|
73 |
+
pad_mode="reflect",
|
74 |
+
normalized=False,
|
75 |
+
onesided=True,
|
76 |
+
return_complex=True,
|
77 |
+
)
|
78 |
+
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
79 |
+
|
80 |
+
mel_spec = torch.matmul(mel_basis, spec)
|
81 |
+
mel_spec = spectral_normalize_torch(mel_spec)
|
82 |
+
|
83 |
+
return mel_spec
|
84 |
+
|
85 |
+
|
86 |
+
class AttrDict(dict):
|
87 |
+
def __init__(self, *args, **kwargs):
|
88 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
89 |
+
self.__dict__ = self
|
90 |
+
|
91 |
+
def load_checkpoint(filepath, device):
|
92 |
+
assert os.path.isfile(filepath)
|
93 |
+
print(f"Loading '{filepath}'")
|
94 |
+
checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
|
95 |
+
print("Complete.")
|
96 |
+
return checkpoint_dict
|
97 |
+
|
98 |
+
def init_weights(m, mean=0.0, std=0.01):
|
99 |
+
classname = m.__class__.__name__
|
100 |
+
if classname.find("Conv") != -1:
|
101 |
+
m.weight.data.normal_(mean, std)
|
102 |
+
|
103 |
+
|
104 |
+
def get_padding(kernel_size, dilation=1):
|
105 |
+
return int((kernel_size * dilation - dilation) / 2)
|
modules/audio_tokenizer/audio_tokenizer.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import librosa
|
3 |
+
import yaml
|
4 |
+
from transformers import Wav2Vec2BertModel, SeamlessM4TFeatureExtractor
|
5 |
+
import safetensors
|
6 |
+
import accelerate
|
7 |
+
import soundfile as sf
|
8 |
+
import math
|
9 |
+
from einops import rearrange
|
10 |
+
from modules.audio_tokenizer.rep_codec import RepCodec
|
11 |
+
|
12 |
+
|
13 |
+
class AudioTokenizer(object):
|
14 |
+
def __init__(self, **kwargs):
|
15 |
+
self.device = kwargs.pop('device')
|
16 |
+
print(self.device)
|
17 |
+
# tokenize
|
18 |
+
feat_stats = kwargs.pop('feat_stats')
|
19 |
+
feat_stats = torch.load(feat_stats, map_location='cpu')
|
20 |
+
self.feat_mean = feat_stats['mean']
|
21 |
+
self.feat_std = torch.sqrt(feat_stats['var'])
|
22 |
+
wav2vec_ckpt = kwargs.pop("wav2vec_ckpt")
|
23 |
+
self.semantic_model = Wav2Vec2BertModel.from_pretrained(wav2vec_ckpt)
|
24 |
+
self.semantic_model.eval()
|
25 |
+
self.semantic_model.to(self.device)
|
26 |
+
self.semantic_processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
|
27 |
+
|
28 |
+
self.semantic_codec = RepCodec()
|
29 |
+
self.semantic_codec.eval()
|
30 |
+
pretrained_path = kwargs.pop("semantic_codec_ckpt")
|
31 |
+
safetensors.torch.load_model(self.semantic_codec, pretrained_path)
|
32 |
+
self.semantic_codec.to(self.device)
|
33 |
+
|
34 |
+
self.max_length = 2048
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def tokenize(self, speech):
|
39 |
+
# Input:
|
40 |
+
# speech: torch tensor, shape[B, N_speech]
|
41 |
+
# Output:
|
42 |
+
# semantic token: torch tensor, shape[B, N]
|
43 |
+
|
44 |
+
inputs = self.semantic_processor(speech.cpu(), sampling_rate=16000, return_tensors="pt")
|
45 |
+
input_features = inputs["input_features"].to(self.device)
|
46 |
+
attention_mask = inputs["attention_mask"].to(self.device)
|
47 |
+
seg_num = math.ceil(input_features.shape[1] / self.max_length)
|
48 |
+
pad_num = seg_num * self.max_length - input_features.shape[1]
|
49 |
+
input_features = torch.nn.functional.pad(input_features, (0, 0, 0, pad_num, 0,0), value=0)
|
50 |
+
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_num, 0, 0), value=0)
|
51 |
+
input_features = rearrange(input_features, "b (s n) d -> (b s) n d", s =seg_num)
|
52 |
+
attention_mask = rearrange(attention_mask, "b (s n) -> (b s) n", s=seg_num)
|
53 |
+
|
54 |
+
|
55 |
+
feats = self.semantic_model(
|
56 |
+
input_features=input_features,
|
57 |
+
attention_mask=attention_mask,
|
58 |
+
output_hidden_states=True,
|
59 |
+
)
|
60 |
+
feat = feats.hidden_states[17]
|
61 |
+
feat = rearrange(feat, "(b s) n d -> b (s n) d", s=seg_num)
|
62 |
+
feat = feat[:, :feat.shape[1]-pad_num, :]
|
63 |
+
feat = (feat - self.feat_mean.to(feat)) / self.feat_std.to(feat)
|
64 |
+
semantic_token, _ = self.semantic_codec.quantize(feat)
|
65 |
+
return semantic_token
|
66 |
+
|
67 |
+
def get_audio_tokenizer():
|
68 |
+
config = dict()
|
69 |
+
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
|
70 |
+
config['feat_stats'] = 'resources/audio_tokenizer/stats.pt'
|
71 |
+
config['wav2vec_ckpt'] = 'facebook/w2v-bert-2.0'
|
72 |
+
config['semantic_codec_ckpt'] = 'resources/audio_tokenizer/model.safetensors'
|
73 |
+
audio_tokenizer = AudioTokenizer(**config)
|
74 |
+
return audio_tokenizer
|
75 |
+
|
76 |
+
get_audio_tokenizer()
|
modules/audio_tokenizer/quantize/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .vector_quantize import VectorQuantize
|
2 |
+
from .residual_vq import ResidualVQ
|
3 |
+
from .factorized_vector_quantize import FactorizedVectorQuantize
|
modules/audio_tokenizer/quantize/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
class FactorizedVectorQuantize(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
input_dim,
|
21 |
+
codebook_size,
|
22 |
+
codebook_dim,
|
23 |
+
commitment=0.005,
|
24 |
+
codebook_loss_weight=1.0,
|
25 |
+
use_l2_normlize=True,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.input_dim = input_dim
|
29 |
+
self.codebook_size = codebook_size
|
30 |
+
self.codebook_dim = codebook_dim
|
31 |
+
self.commitment = commitment
|
32 |
+
self.codebook_loss_weight = codebook_loss_weight
|
33 |
+
self.use_l2_normlize = use_l2_normlize
|
34 |
+
|
35 |
+
if self.input_dim != self.codebook_dim:
|
36 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
37 |
+
self.out_project = WNConv1d(
|
38 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
39 |
+
)
|
40 |
+
|
41 |
+
else:
|
42 |
+
self.in_project = nn.Identity()
|
43 |
+
self.out_project = nn.Identity()
|
44 |
+
|
45 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
46 |
+
|
47 |
+
def forward(self, z):
|
48 |
+
"""
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
z: torch.Tensor[B x D x T]
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
z_q: torch.Tensor[B x D x T]
|
56 |
+
Quantized continuous representation of input
|
57 |
+
commit_loss: Tensor[B]
|
58 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
59 |
+
codebook_loss: Tensor[B]
|
60 |
+
Codebook loss to update the codebook
|
61 |
+
indices: torch.Tensor[B x T]
|
62 |
+
Codebook indices (quantized discrete representation of input)
|
63 |
+
z_e: torch.Tensor[B x D x T]
|
64 |
+
Projected latents (continuous representation of input before quantization)
|
65 |
+
"""
|
66 |
+
|
67 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
68 |
+
z_e = self.in_project(z)
|
69 |
+
z_q, indices = self.decode_latents(z_e)
|
70 |
+
|
71 |
+
# Compute commitment loss and codebook loss
|
72 |
+
if self.training:
|
73 |
+
commit_loss = (
|
74 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
75 |
+
* self.commitment
|
76 |
+
)
|
77 |
+
codebook_loss = (
|
78 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
79 |
+
* self.codebook_loss_weight
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
83 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
84 |
+
|
85 |
+
z_q = z_e + (z_q - z_e).detach()
|
86 |
+
|
87 |
+
z_q = self.out_project(z_q)
|
88 |
+
|
89 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
90 |
+
|
91 |
+
def embed_code(self, embed_id):
|
92 |
+
return F.embedding(embed_id, self.codebook.weight)
|
93 |
+
|
94 |
+
def decode_code(self, embed_id):
|
95 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
96 |
+
|
97 |
+
def decode_latents(self, latents):
|
98 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
99 |
+
codebook = self.codebook.weight
|
100 |
+
|
101 |
+
# L2 normalize encodings and codebook
|
102 |
+
if self.use_l2_normlize:
|
103 |
+
encodings = F.normalize(encodings)
|
104 |
+
codebook = F.normalize(codebook)
|
105 |
+
|
106 |
+
# Compute euclidean distance between encodings and codebook,
|
107 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
108 |
+
dist = (
|
109 |
+
encodings.pow(2).sum(1, keepdim=True)
|
110 |
+
- 2 * encodings @ codebook.t()
|
111 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
112 |
+
)
|
113 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
114 |
+
z_q = self.decode_code(indices)
|
115 |
+
|
116 |
+
return z_q, indices
|
117 |
+
|
118 |
+
def vq2emb(self, vq, out_proj=True):
|
119 |
+
emb = self.decode_code(vq)
|
120 |
+
if out_proj:
|
121 |
+
emb = self.out_project(emb)
|
122 |
+
return emb
|
123 |
+
|
124 |
+
def latent2dist(self, latents):
|
125 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
126 |
+
codebook = self.codebook.weight
|
127 |
+
|
128 |
+
# L2 normalize encodings and codebook
|
129 |
+
if self.use_l2_normlize:
|
130 |
+
encodings = F.normalize(encodings)
|
131 |
+
codebook = F.normalize(codebook)
|
132 |
+
|
133 |
+
# Compute euclidean distance between encodings and codebook,
|
134 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
135 |
+
dist = (
|
136 |
+
encodings.pow(2).sum(1, keepdim=True)
|
137 |
+
- 2 * encodings @ codebook.t()
|
138 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
139 |
+
) # (b*t, k)
|
140 |
+
|
141 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
142 |
+
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
|
143 |
+
z_q = self.decode_code(indices)
|
144 |
+
|
145 |
+
return -dist, indices, z_q
|
modules/audio_tokenizer/quantize/residual_vq.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
|
11 |
+
from .vector_quantize import VectorQuantize
|
12 |
+
from .factorized_vector_quantize import FactorizedVectorQuantize
|
13 |
+
|
14 |
+
|
15 |
+
class ResidualVQ(nn.Module):
|
16 |
+
"""
|
17 |
+
Introduced in SoundStream: An end2end neural audio codec
|
18 |
+
https://arxiv.org/abs/2107.03312
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
input_dim: int = 256,
|
24 |
+
num_quantizers: int = 8,
|
25 |
+
codebook_size: int = 1024,
|
26 |
+
codebook_dim: int = 256,
|
27 |
+
quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
|
28 |
+
quantizer_dropout: float = 0.5,
|
29 |
+
**kwargs,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.input_dim = input_dim
|
34 |
+
self.num_quantizers = num_quantizers
|
35 |
+
self.codebook_size = codebook_size
|
36 |
+
self.codebook_dim = codebook_dim
|
37 |
+
self.quantizer_type = quantizer_type
|
38 |
+
self.quantizer_dropout = quantizer_dropout
|
39 |
+
|
40 |
+
if quantizer_type == "vq":
|
41 |
+
VQ = VectorQuantize
|
42 |
+
elif quantizer_type == "fvq":
|
43 |
+
VQ = FactorizedVectorQuantize
|
44 |
+
else:
|
45 |
+
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
46 |
+
|
47 |
+
self.quantizers = nn.ModuleList(
|
48 |
+
[
|
49 |
+
VQ(
|
50 |
+
input_dim=input_dim,
|
51 |
+
codebook_size=codebook_size,
|
52 |
+
codebook_dim=codebook_dim,
|
53 |
+
**kwargs,
|
54 |
+
)
|
55 |
+
for _ in range(num_quantizers)
|
56 |
+
]
|
57 |
+
)
|
58 |
+
|
59 |
+
def forward(self, z, n_quantizers: int = None):
|
60 |
+
"""
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
z : Tensor[B x D x T]
|
64 |
+
n_quantizers : int, optional
|
65 |
+
No. of quantizers to use
|
66 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
67 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
68 |
+
when in training mode, and a random number of quantizers is used.
|
69 |
+
Returns
|
70 |
+
-------
|
71 |
+
"quantized_out" : Tensor[B x D x T]
|
72 |
+
Quantized continuous representation of input
|
73 |
+
"all_indices" : Tensor[N x B x T]
|
74 |
+
Codebook indices for each codebook
|
75 |
+
(quantized discrete representation of input)
|
76 |
+
"all_commit_losses" : Tensor[N]
|
77 |
+
"all_codebook_losses" : Tensor[N]
|
78 |
+
"all_quantized" : Tensor[N x B x D x T]
|
79 |
+
"""
|
80 |
+
|
81 |
+
quantized_out = 0.0
|
82 |
+
residual = z
|
83 |
+
|
84 |
+
all_commit_losses = []
|
85 |
+
all_codebook_losses = []
|
86 |
+
all_indices = []
|
87 |
+
all_quantized = []
|
88 |
+
|
89 |
+
if n_quantizers is None:
|
90 |
+
n_quantizers = self.num_quantizers
|
91 |
+
|
92 |
+
if self.training:
|
93 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
|
94 |
+
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
|
95 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
96 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
97 |
+
n_quantizers = n_quantizers.to(z.device)
|
98 |
+
|
99 |
+
for i, quantizer in enumerate(self.quantizers):
|
100 |
+
if self.training is False and i >= n_quantizers:
|
101 |
+
break
|
102 |
+
|
103 |
+
z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
104 |
+
residual
|
105 |
+
)
|
106 |
+
|
107 |
+
# Create mask to apply quantizer dropout
|
108 |
+
mask = (
|
109 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
110 |
+
)
|
111 |
+
quantized_out = quantized_out + z_q_i * mask[:, None, None]
|
112 |
+
residual = residual - z_q_i
|
113 |
+
|
114 |
+
commit_loss_i = (commit_loss_i * mask).mean()
|
115 |
+
codebook_loss_i = (codebook_loss_i * mask).mean()
|
116 |
+
|
117 |
+
all_commit_losses.append(commit_loss_i)
|
118 |
+
all_codebook_losses.append(codebook_loss_i)
|
119 |
+
all_indices.append(indices_i)
|
120 |
+
all_quantized.append(z_q_i)
|
121 |
+
|
122 |
+
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
|
123 |
+
torch.stack,
|
124 |
+
(all_commit_losses, all_codebook_losses, all_indices, all_quantized),
|
125 |
+
)
|
126 |
+
|
127 |
+
return (
|
128 |
+
quantized_out,
|
129 |
+
all_indices,
|
130 |
+
all_commit_losses,
|
131 |
+
all_codebook_losses,
|
132 |
+
all_quantized,
|
133 |
+
)
|
134 |
+
|
135 |
+
def vq2emb(self, vq, n_quantizers=None):
|
136 |
+
quantized_out = 0.0
|
137 |
+
if n_quantizers is None:
|
138 |
+
n_quantizers = self.num_quantizers
|
139 |
+
for idx, quantizer in enumerate(self.quantizers):
|
140 |
+
if idx >= n_quantizers:
|
141 |
+
break
|
142 |
+
quantized_out += quantizer.vq2emb(vq[idx])
|
143 |
+
return quantized_out
|
144 |
+
|
145 |
+
def latent2dist(self, z, n_quantizers=None):
|
146 |
+
quantized_out = 0.0
|
147 |
+
residual = z
|
148 |
+
|
149 |
+
all_dists = []
|
150 |
+
all_indices = []
|
151 |
+
|
152 |
+
if n_quantizers is None:
|
153 |
+
n_quantizers = self.num_quantizers
|
154 |
+
|
155 |
+
for i, quantizer in enumerate(self.quantizers):
|
156 |
+
if self.training is False and i >= n_quantizers:
|
157 |
+
break
|
158 |
+
dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
|
159 |
+
all_dists.append(dist_i)
|
160 |
+
all_indices.append(indices_i)
|
161 |
+
|
162 |
+
quantized_out = quantized_out + z_q_i
|
163 |
+
residual = residual - z_q_i
|
164 |
+
|
165 |
+
all_dists = torch.stack(all_dists)
|
166 |
+
all_indices = torch.stack(all_indices)
|
167 |
+
|
168 |
+
return all_dists, all_indices
|
modules/audio_tokenizer/quantize/vector_quantize.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
def l2norm(t):
|
18 |
+
return F.normalize(t, p=2, dim=-1)
|
19 |
+
|
20 |
+
|
21 |
+
def ema_inplace(moving_avg, new, decay):
|
22 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
23 |
+
|
24 |
+
|
25 |
+
def laplace_smoothing(x, n_categories, eps=1e-5):
|
26 |
+
return (x + eps) / (x.sum() + n_categories * eps)
|
27 |
+
|
28 |
+
|
29 |
+
def sample_vectors(samples, num):
|
30 |
+
num_samples, device = samples.shape[0], samples.device
|
31 |
+
|
32 |
+
if num_samples >= num:
|
33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
34 |
+
else:
|
35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
36 |
+
|
37 |
+
return samples[indices]
|
38 |
+
|
39 |
+
|
40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
41 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
42 |
+
|
43 |
+
means = sample_vectors(samples, num_clusters)
|
44 |
+
|
45 |
+
for _ in range(num_iters):
|
46 |
+
if use_cosine_sim:
|
47 |
+
dists = samples @ means.t()
|
48 |
+
else:
|
49 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
50 |
+
means, "c d -> () c d"
|
51 |
+
)
|
52 |
+
dists = -(diffs**2).sum(dim=-1)
|
53 |
+
|
54 |
+
buckets = dists.max(dim=-1).indices
|
55 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
56 |
+
zero_mask = bins == 0
|
57 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
58 |
+
|
59 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
60 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
61 |
+
new_means = new_means / bins_min_clamped[..., None]
|
62 |
+
|
63 |
+
if use_cosine_sim:
|
64 |
+
new_means = l2norm(new_means)
|
65 |
+
|
66 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
67 |
+
|
68 |
+
return means, bins
|
69 |
+
|
70 |
+
|
71 |
+
class EuclideanCodebook(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
dim,
|
75 |
+
codebook_size,
|
76 |
+
kmeans_init=False,
|
77 |
+
kmeans_iters=10,
|
78 |
+
decay=0.8,
|
79 |
+
eps=1e-5,
|
80 |
+
threshold_ema_dead_code=2,
|
81 |
+
weight_init=False,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.decay = decay
|
86 |
+
init_fn = torch.randn if not weight_init else torch.zeros
|
87 |
+
embed = init_fn(codebook_size, dim)
|
88 |
+
|
89 |
+
if weight_init:
|
90 |
+
nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
|
91 |
+
|
92 |
+
self.codebook_size = codebook_size
|
93 |
+
self.kmeans_iters = kmeans_iters
|
94 |
+
self.eps = eps
|
95 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
96 |
+
|
97 |
+
self.register_buffer(
|
98 |
+
"initted", torch.Tensor([not kmeans_init])
|
99 |
+
) # if kmeans_init is True, then initted is False; otherwise, initted is True
|
100 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
101 |
+
self.register_buffer("embed", embed)
|
102 |
+
self.register_buffer("embed_avg", embed.clone())
|
103 |
+
|
104 |
+
def init_embed_(self, data):
|
105 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
106 |
+
self.embed.data.copy_(embed)
|
107 |
+
self.embed_avg.data.copy_(embed)
|
108 |
+
self.cluster_size.data.copy_(cluster_size)
|
109 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
110 |
+
|
111 |
+
def replace(self, samples, mask):
|
112 |
+
modified_codebook = torch.where(
|
113 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
114 |
+
)
|
115 |
+
self.embed.data.copy_(modified_codebook)
|
116 |
+
|
117 |
+
def expire_codes_(self, batch_samples):
|
118 |
+
if self.threshold_ema_dead_code == 0:
|
119 |
+
return
|
120 |
+
|
121 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
122 |
+
if not torch.any(expired_codes):
|
123 |
+
return
|
124 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
125 |
+
self.replace(batch_samples, mask=expired_codes)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
shape, dtype = x.shape, x.dtype
|
129 |
+
flatten = rearrange(x, "... d -> (...) d")
|
130 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
131 |
+
|
132 |
+
if not self.initted:
|
133 |
+
self.init_embed_(flatten)
|
134 |
+
|
135 |
+
dist = -(
|
136 |
+
flatten.pow(2).sum(1, keepdim=True)
|
137 |
+
- 2 * flatten @ embed
|
138 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
139 |
+
)
|
140 |
+
|
141 |
+
embed_ind = dist.max(dim=-1).indices
|
142 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
143 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
144 |
+
quantize = F.embedding(embed_ind, self.embed)
|
145 |
+
|
146 |
+
if self.training:
|
147 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
148 |
+
embed_sum = (
|
149 |
+
flatten.t() @ embed_onehot
|
150 |
+
) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
|
151 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
152 |
+
cluster_size = (
|
153 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
|
154 |
+
* self.cluster_size.sum()
|
155 |
+
)
|
156 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
157 |
+
self.embed.data.copy_(embed_normalized)
|
158 |
+
self.expire_codes_(x)
|
159 |
+
|
160 |
+
return quantize, embed_ind
|
161 |
+
|
162 |
+
def vq2emb(self, vq):
|
163 |
+
quantize = F.embedding(vq, self.embed)
|
164 |
+
return quantize
|
165 |
+
|
166 |
+
def latent2dist(self, x):
|
167 |
+
shape, dtype = x.shape, x.dtype
|
168 |
+
flatten = rearrange(x, "... d -> (...) d")
|
169 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
170 |
+
|
171 |
+
if not self.initted:
|
172 |
+
self.init_embed_(flatten)
|
173 |
+
|
174 |
+
dist = -(
|
175 |
+
flatten.pow(2).sum(1, keepdim=True)
|
176 |
+
- 2 * flatten @ embed
|
177 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
178 |
+
)
|
179 |
+
|
180 |
+
embed_ind = dist.max(dim=-1).indices
|
181 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
182 |
+
quantize = F.embedding(embed_ind, self.embed)
|
183 |
+
|
184 |
+
dist = dist.view(*shape[:-1], -1)
|
185 |
+
|
186 |
+
return dist, embed_ind, quantize
|
187 |
+
|
188 |
+
|
189 |
+
class SimpleCodebook(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
dim,
|
193 |
+
codebook_size,
|
194 |
+
use_l2_normlize=False,
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
self.dim = dim
|
199 |
+
self.codebook_size = codebook_size
|
200 |
+
self.use_l2_normlize = use_l2_normlize
|
201 |
+
|
202 |
+
self.embed = nn.Embedding(self.codebook_size, self.dim)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
shape, dtype = x.shape, x.dtype
|
206 |
+
flatten = rearrange(x, "... d -> (...) d")
|
207 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
208 |
+
|
209 |
+
if self.use_l2_normlize:
|
210 |
+
flatten = F.normalize(flatten)
|
211 |
+
embed = F.normalize(embed)
|
212 |
+
|
213 |
+
dist = -(
|
214 |
+
flatten.pow(2).sum(1, keepdim=True)
|
215 |
+
- 2 * flatten @ embed
|
216 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
217 |
+
)
|
218 |
+
|
219 |
+
embed_ind = dist.max(dim=-1).indices
|
220 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
221 |
+
quantize = F.embedding(embed_ind, self.embed)
|
222 |
+
|
223 |
+
return quantize, embed_ind
|
224 |
+
|
225 |
+
def vq2emb(self, vq):
|
226 |
+
quantize = F.embedding(vq, self.embed.weight)
|
227 |
+
return quantize
|
228 |
+
|
229 |
+
def latent2dist(self, x):
|
230 |
+
shape, dtype = x.shape, x.dtype
|
231 |
+
flatten = rearrange(x, "... d -> (...) d")
|
232 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
233 |
+
|
234 |
+
if self.use_l2_normlize:
|
235 |
+
flatten = F.normalize(flatten)
|
236 |
+
embed = F.normalize(embed)
|
237 |
+
|
238 |
+
dist = -(
|
239 |
+
flatten.pow(2).sum(1, keepdim=True)
|
240 |
+
- 2 * flatten @ embed
|
241 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
242 |
+
)
|
243 |
+
|
244 |
+
embed_ind = dist.max(dim=-1).indices
|
245 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
246 |
+
quantize = F.embedding(embed_ind, self.embed)
|
247 |
+
|
248 |
+
dist = dist.view(*shape[:-1], -1)
|
249 |
+
|
250 |
+
return dist, embed_ind, quantize
|
251 |
+
|
252 |
+
|
253 |
+
class VectorQuantize(nn.Module):
|
254 |
+
"""Vector quantization and factorized vecotor quantization implementation
|
255 |
+
Args:
|
256 |
+
input_dim (int): Dimension of input.
|
257 |
+
codebook_size (int): Codebook size.
|
258 |
+
codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
|
259 |
+
if use codebook_type == "euclidean", otherwise, if you want to use
|
260 |
+
factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
|
261 |
+
commitment (float): Weight for commitment loss.
|
262 |
+
use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
|
263 |
+
we suggest use it as True if you want to use factorized vector quantization
|
264 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
265 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
266 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
267 |
+
epsilon (float): Epsilon value for numerical stability.
|
268 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
269 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
270 |
+
randomly selected vector from the current batch.
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
input_dim,
|
276 |
+
codebook_size,
|
277 |
+
codebook_dim,
|
278 |
+
commitment=0.005,
|
279 |
+
codebook_loss_weight=1.0,
|
280 |
+
use_l2_normlize=False,
|
281 |
+
codebook_type="euclidean", # "euclidean" or "simple"
|
282 |
+
kmeans_init=False,
|
283 |
+
kmeans_iters=10,
|
284 |
+
decay=0.8,
|
285 |
+
eps=1e-5,
|
286 |
+
threshold_ema_dead_code=2,
|
287 |
+
weight_init=False,
|
288 |
+
):
|
289 |
+
super().__init__()
|
290 |
+
self.input_dim = input_dim
|
291 |
+
self.codebook_size = codebook_size
|
292 |
+
self.codebook_dim = codebook_dim
|
293 |
+
self.commitment = commitment
|
294 |
+
self.codebook_loss_weight = codebook_loss_weight
|
295 |
+
self.use_l2_normlize = use_l2_normlize
|
296 |
+
self.codebook_type = codebook_type
|
297 |
+
self.kmeans_init = kmeans_init
|
298 |
+
self.kmeans_iters = kmeans_iters
|
299 |
+
self.decay = decay
|
300 |
+
self.eps = eps
|
301 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
302 |
+
self.weight_init = weight_init
|
303 |
+
|
304 |
+
if self.input_dim != self.codebook_dim:
|
305 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
306 |
+
self.out_project = WNConv1d(
|
307 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
308 |
+
)
|
309 |
+
|
310 |
+
else:
|
311 |
+
self.in_project = nn.Identity()
|
312 |
+
self.out_project = nn.Identity()
|
313 |
+
|
314 |
+
if self.codebook_type == "euclidean":
|
315 |
+
self.codebook = EuclideanCodebook(
|
316 |
+
self.codebook_dim,
|
317 |
+
codebook_size=self.codebook_size,
|
318 |
+
kmeans_init=self.kmeans_init,
|
319 |
+
kmeans_iters=self.kmeans_iters,
|
320 |
+
decay=self.decay,
|
321 |
+
eps=self.eps,
|
322 |
+
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
323 |
+
weight_init=self.weight_init,
|
324 |
+
)
|
325 |
+
elif self.codebook_type == "simple":
|
326 |
+
self.codebook = SimpleCodebook(
|
327 |
+
self.codebook_dim,
|
328 |
+
codebook_size=self.codebook_size,
|
329 |
+
use_l2_normlize=self.use_l2_normlize,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
raise NotImplementedError(
|
333 |
+
f"codebook_type {self.codebook_type} is not implemented!"
|
334 |
+
)
|
335 |
+
|
336 |
+
def forward(self, z):
|
337 |
+
"""
|
338 |
+
Parameters
|
339 |
+
----------
|
340 |
+
z: torch.Tensor[B x D x T]
|
341 |
+
|
342 |
+
Returns
|
343 |
+
-------
|
344 |
+
z_q: torch.Tensor[B x D x T]
|
345 |
+
Quantized continuous representation of input
|
346 |
+
commit_loss: Tensor[B]
|
347 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
348 |
+
codebook_loss: Tensor[B]
|
349 |
+
Codebook loss to update the codebook
|
350 |
+
indices: torch.Tensor[B x T]
|
351 |
+
Codebook indices (quantized discrete representation of input)
|
352 |
+
z_e: torch.Tensor[B x D x T]
|
353 |
+
Projected latents (continuous representation of input before quantization)
|
354 |
+
"""
|
355 |
+
|
356 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
357 |
+
z_e = self.in_project(z)
|
358 |
+
z_q, indices = self.decode_latents(z_e)
|
359 |
+
|
360 |
+
# Compute commitment loss and codebook loss
|
361 |
+
if self.training:
|
362 |
+
commit_loss = (
|
363 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
364 |
+
* self.commitment
|
365 |
+
)
|
366 |
+
codebook_loss = (
|
367 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
368 |
+
* self.codebook_loss_weight
|
369 |
+
)
|
370 |
+
else:
|
371 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
372 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
373 |
+
|
374 |
+
z_q = z_e + (z_q - z_e).detach()
|
375 |
+
|
376 |
+
z_q = self.out_project(z_q)
|
377 |
+
|
378 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
379 |
+
|
380 |
+
def decode_latents(self, latents):
|
381 |
+
encodings = rearrange(latents, "b d t -> b t d")
|
382 |
+
z_q, indices = self.codebook(encodings)
|
383 |
+
z_q = z_q.transpose(1, 2)
|
384 |
+
return z_q, indices
|
385 |
+
|
386 |
+
def vq2emb(self, vq, out_proj=True):
|
387 |
+
emb = self.codebook.vq2emb(vq)
|
388 |
+
emb = emb.transpose(1, 2)
|
389 |
+
if out_proj:
|
390 |
+
emb = self.out_project(emb)
|
391 |
+
return emb
|
392 |
+
|
393 |
+
def latent2dist(self, latents):
|
394 |
+
latents = rearrange(latents, "b d t -> b t d")
|
395 |
+
dist, embed_ind, quantize = self.codebook.latent2dist(latents)
|
396 |
+
return dist, embed_ind, quantize.transpose(1, 2)
|
modules/audio_tokenizer/rep_codec.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
from modules.audio_tokenizer.quantize import ResidualVQ
|
6 |
+
from modules.audio_tokenizer.vocos import VocosBackbone
|
7 |
+
from modules.audio_tokenizer.transformer import TransformerEncoder
|
8 |
+
|
9 |
+
def init_weights(m):
|
10 |
+
if isinstance(m, nn.Conv1d):
|
11 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
12 |
+
nn.init.constant_(m.bias, 0)
|
13 |
+
if isinstance(m, nn.Linear):
|
14 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
15 |
+
nn.init.constant_(m.bias, 0)
|
16 |
+
|
17 |
+
class RepCodec(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
codebook_size=8192,
|
21 |
+
hidden_size=1024,
|
22 |
+
codebook_dim=8,
|
23 |
+
vocos_dim=384,
|
24 |
+
vocos_intermediate_dim=2048,
|
25 |
+
vocos_num_layers=12,
|
26 |
+
num_quantizers=1,
|
27 |
+
use_timbre_encoder=False,
|
28 |
+
cfg=None,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
codebook_size = (
|
32 |
+
cfg.codebook_size
|
33 |
+
if cfg is not None and hasattr(cfg, "codebook_size")
|
34 |
+
else codebook_size
|
35 |
+
)
|
36 |
+
codebook_dim = (
|
37 |
+
cfg.codebook_dim
|
38 |
+
if cfg is not None and hasattr(cfg, "codebook_dim")
|
39 |
+
else codebook_dim
|
40 |
+
)
|
41 |
+
hidden_size = (
|
42 |
+
cfg.hidden_size
|
43 |
+
if cfg is not None and hasattr(cfg, "hidden_size")
|
44 |
+
else hidden_size
|
45 |
+
)
|
46 |
+
vocos_dim = (
|
47 |
+
cfg.vocos_dim
|
48 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
49 |
+
else vocos_dim
|
50 |
+
)
|
51 |
+
vocos_intermediate_dim = (
|
52 |
+
cfg.vocos_intermediate_dim
|
53 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
54 |
+
else vocos_intermediate_dim
|
55 |
+
)
|
56 |
+
vocos_num_layers = (
|
57 |
+
cfg.vocos_num_layers
|
58 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
59 |
+
else vocos_num_layers
|
60 |
+
)
|
61 |
+
num_quantizers = (
|
62 |
+
cfg.num_quantizers
|
63 |
+
if cfg is not None and hasattr(cfg, "num_quantizers")
|
64 |
+
else num_quantizers
|
65 |
+
)
|
66 |
+
use_timbre_encoder = (
|
67 |
+
cfg.use_timbre_encoder
|
68 |
+
if cfg is not None and hasattr(cfg, "use_timbre_encoder")
|
69 |
+
else use_timbre_encoder
|
70 |
+
)
|
71 |
+
|
72 |
+
self.codebook_size = codebook_size
|
73 |
+
self.codebook_dim = codebook_dim
|
74 |
+
self.hidden_size = hidden_size
|
75 |
+
self.vocos_dim = vocos_dim
|
76 |
+
self.vocos_intermediate_dim = vocos_intermediate_dim
|
77 |
+
self.vocos_num_layers = vocos_num_layers
|
78 |
+
self.num_quantizers = num_quantizers
|
79 |
+
self.use_timbre_encoder = use_timbre_encoder
|
80 |
+
|
81 |
+
self.encoder = nn.Sequential(
|
82 |
+
VocosBackbone(
|
83 |
+
input_channels=self.hidden_size,
|
84 |
+
dim=384,
|
85 |
+
intermediate_dim=2048,
|
86 |
+
num_layers=12,
|
87 |
+
adanorm_num_embeddings=None
|
88 |
+
),
|
89 |
+
nn.Linear(384, self.hidden_size)
|
90 |
+
)
|
91 |
+
self.decoder = nn.Sequential(
|
92 |
+
VocosBackbone(
|
93 |
+
input_channels=self.hidden_size,
|
94 |
+
dim=384,
|
95 |
+
intermediate_dim=2048,
|
96 |
+
num_layers=12,
|
97 |
+
adanorm_num_embeddings=None
|
98 |
+
),
|
99 |
+
nn.Linear(384, self.hidden_size)
|
100 |
+
)
|
101 |
+
|
102 |
+
self.quantizer = ResidualVQ(
|
103 |
+
input_dim=hidden_size,
|
104 |
+
num_quantizers=num_quantizers,
|
105 |
+
codebook_size=codebook_size,
|
106 |
+
codebook_dim=codebook_dim,
|
107 |
+
quantizer_type="fvq",
|
108 |
+
quantizer_dropout=0.0,
|
109 |
+
commitment=0.15,
|
110 |
+
codebook_loss_weight=1.0,
|
111 |
+
use_l2_normlize=True,
|
112 |
+
)
|
113 |
+
|
114 |
+
if self.use_timbre_encoder: #TODO: write encoder hidden (256) as a hyparam
|
115 |
+
self.timbre_in = nn.Linear(hidden_size, 256)
|
116 |
+
self.timbre_encoder = TransformerEncoder(
|
117 |
+
enc_emb_tokens=None,
|
118 |
+
encoder_layer=4,
|
119 |
+
encoder_hidden=256,
|
120 |
+
encoder_head=4,
|
121 |
+
conv_filter_size=1024,
|
122 |
+
conv_kernel_size=5,
|
123 |
+
encoder_dropout=0.1,
|
124 |
+
use_pe=False,
|
125 |
+
cfg=None,
|
126 |
+
)
|
127 |
+
self.timbre_out = nn.Linear(256, hidden_size)
|
128 |
+
self.timbre_linear = nn.Linear(hidden_size, hidden_size * 2)
|
129 |
+
self.timbre_linear.bias.data[:hidden_size] = 1
|
130 |
+
self.timbre_linear.bias.data[hidden_size:] = 0
|
131 |
+
self.timbre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
|
132 |
+
self.enc_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
|
133 |
+
|
134 |
+
self.reset_parameters()
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
|
138 |
+
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
139 |
+
|
140 |
+
if self.use_timbre_encoder:
|
141 |
+
x_timbre = x
|
142 |
+
x = x.transpose(1, 2)
|
143 |
+
x = self.enc_ln(x)
|
144 |
+
x = x.transpose(1, 2)
|
145 |
+
|
146 |
+
(
|
147 |
+
quantized_out,
|
148 |
+
all_indices,
|
149 |
+
all_commit_losses,
|
150 |
+
all_codebook_losses,
|
151 |
+
_,
|
152 |
+
) = self.quantizer(x)
|
153 |
+
|
154 |
+
if self.use_timbre_encoder:
|
155 |
+
x_timbre = x_timbre.transpose(1, 2)
|
156 |
+
x_timbre = self.timbre_in(x_timbre)
|
157 |
+
x_timbre = self.timbre_encoder(x_timbre, None, None)
|
158 |
+
x_timbre = self.timbre_out(x_timbre)
|
159 |
+
x_timbre = x_timbre.transpose(1, 2)
|
160 |
+
spk_embs = torch.mean(x_timbre, dim=2)
|
161 |
+
|
162 |
+
style = self.timbre_linear(spk_embs).unsqueeze(2) # (B, 2d, 1)
|
163 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
164 |
+
quantized_out = quantized_out.transpose(1, 2)
|
165 |
+
quantized_out = self.timbre_norm(quantized_out)
|
166 |
+
quantized_out = quantized_out.transpose(1, 2)
|
167 |
+
quantized_out = quantized_out * gamma + beta
|
168 |
+
|
169 |
+
|
170 |
+
x_rec = self.decoder(quantized_out)
|
171 |
+
|
172 |
+
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
|
173 |
+
all_indices = all_indices
|
174 |
+
|
175 |
+
return x_rec, codebook_loss, all_indices
|
176 |
+
|
177 |
+
def quantize(self, x):
|
178 |
+
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
179 |
+
|
180 |
+
if self.use_timbre_encoder:
|
181 |
+
x = x.transpose(1, 2)
|
182 |
+
x = self.enc_ln(x)
|
183 |
+
x = x.transpose(1, 2)
|
184 |
+
|
185 |
+
(
|
186 |
+
quantized_out,
|
187 |
+
all_indices,
|
188 |
+
all_commit_losses,
|
189 |
+
all_codebook_losses,
|
190 |
+
_,
|
191 |
+
) = self.quantizer(x)
|
192 |
+
if all_indices.shape[0] == 1:
|
193 |
+
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
|
194 |
+
return all_indices, quantized_out.transpose(1, 2)
|
195 |
+
|
196 |
+
def reset_parameters(self):
|
197 |
+
self.apply(init_weights)
|
modules/audio_tokenizer/transformer.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import math
|
6 |
+
|
7 |
+
|
8 |
+
class StyleAdaptiveLayerNorm(nn.Module):
|
9 |
+
def __init__(self, normalized_shape, eps=1e-5):
|
10 |
+
super().__init__()
|
11 |
+
self.in_dim = normalized_shape
|
12 |
+
self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
|
13 |
+
self.style = nn.Linear(self.in_dim, self.in_dim * 2)
|
14 |
+
self.style.bias.data[: self.in_dim] = 1
|
15 |
+
self.style.bias.data[self.in_dim :] = 0
|
16 |
+
|
17 |
+
def forward(self, x, condition):
|
18 |
+
# x: (B, T, d); condition: (B, T, d)
|
19 |
+
|
20 |
+
style = self.style(torch.mean(condition, dim=1, keepdim=True))
|
21 |
+
|
22 |
+
gamma, beta = style.chunk(2, -1)
|
23 |
+
|
24 |
+
out = self.norm(x)
|
25 |
+
|
26 |
+
out = gamma * out + beta
|
27 |
+
return out
|
28 |
+
|
29 |
+
|
30 |
+
class PositionalEncoding(nn.Module):
|
31 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.dropout = dropout
|
35 |
+
position = torch.arange(max_len).unsqueeze(1)
|
36 |
+
div_term = torch.exp(
|
37 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
38 |
+
)
|
39 |
+
pe = torch.zeros(max_len, 1, d_model)
|
40 |
+
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
41 |
+
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
42 |
+
self.register_buffer("pe", pe)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = x + self.pe[: x.size(0)]
|
46 |
+
return F.dropout(x, self.dropout, training=self.training)
|
47 |
+
|
48 |
+
|
49 |
+
class TransformerFFNLayer(nn.Module):
|
50 |
+
def __init__(
|
51 |
+
self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
|
52 |
+
):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.encoder_hidden = encoder_hidden
|
56 |
+
self.conv_filter_size = conv_filter_size
|
57 |
+
self.conv_kernel_size = conv_kernel_size
|
58 |
+
self.encoder_dropout = encoder_dropout
|
59 |
+
|
60 |
+
self.ffn_1 = nn.Conv1d(
|
61 |
+
self.encoder_hidden,
|
62 |
+
self.conv_filter_size,
|
63 |
+
self.conv_kernel_size,
|
64 |
+
padding=self.conv_kernel_size // 2,
|
65 |
+
)
|
66 |
+
self.ffn_1.weight.data.normal_(0.0, 0.02)
|
67 |
+
self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
|
68 |
+
self.ffn_2.weight.data.normal_(0.0, 0.02)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
# x: (B, T, d)
|
72 |
+
x = self.ffn_1(x.permute(0, 2, 1)).permute(
|
73 |
+
0, 2, 1
|
74 |
+
) # (B, T, d) -> (B, d, T) -> (B, T, d)
|
75 |
+
x = F.relu(x)
|
76 |
+
x = F.dropout(x, self.encoder_dropout, training=self.training)
|
77 |
+
x = self.ffn_2(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class TransformerEncoderLayer(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
encoder_hidden,
|
85 |
+
encoder_head,
|
86 |
+
conv_filter_size,
|
87 |
+
conv_kernel_size,
|
88 |
+
encoder_dropout,
|
89 |
+
use_cln,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
self.encoder_hidden = encoder_hidden
|
93 |
+
self.encoder_head = encoder_head
|
94 |
+
self.conv_filter_size = conv_filter_size
|
95 |
+
self.conv_kernel_size = conv_kernel_size
|
96 |
+
self.encoder_dropout = encoder_dropout
|
97 |
+
self.use_cln = use_cln
|
98 |
+
|
99 |
+
if not self.use_cln:
|
100 |
+
self.ln_1 = nn.LayerNorm(self.encoder_hidden)
|
101 |
+
self.ln_2 = nn.LayerNorm(self.encoder_hidden)
|
102 |
+
else:
|
103 |
+
self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
104 |
+
self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
105 |
+
|
106 |
+
self.self_attn = nn.MultiheadAttention(
|
107 |
+
self.encoder_hidden, self.encoder_head, batch_first=True
|
108 |
+
)
|
109 |
+
|
110 |
+
self.ffn = TransformerFFNLayer(
|
111 |
+
self.encoder_hidden,
|
112 |
+
self.conv_filter_size,
|
113 |
+
self.conv_kernel_size,
|
114 |
+
self.encoder_dropout,
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, x, key_padding_mask, conditon=None):
|
118 |
+
# x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
|
119 |
+
|
120 |
+
# self attention
|
121 |
+
residual = x
|
122 |
+
if self.use_cln:
|
123 |
+
x = self.ln_1(x, conditon)
|
124 |
+
else:
|
125 |
+
x = self.ln_1(x)
|
126 |
+
|
127 |
+
if key_padding_mask != None:
|
128 |
+
key_padding_mask_input = ~(key_padding_mask.bool())
|
129 |
+
else:
|
130 |
+
key_padding_mask_input = None
|
131 |
+
x, _ = self.self_attn(
|
132 |
+
query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
|
133 |
+
)
|
134 |
+
x = F.dropout(x, self.encoder_dropout, training=self.training)
|
135 |
+
x = residual + x
|
136 |
+
|
137 |
+
# ffn
|
138 |
+
residual = x
|
139 |
+
if self.use_cln:
|
140 |
+
x = self.ln_2(x, conditon)
|
141 |
+
else:
|
142 |
+
x = self.ln_2(x)
|
143 |
+
x = self.ffn(x)
|
144 |
+
x = residual + x
|
145 |
+
|
146 |
+
return x
|
147 |
+
|
148 |
+
|
149 |
+
class TransformerEncoder(nn.Module):
|
150 |
+
def __init__(
|
151 |
+
self,
|
152 |
+
enc_emb_tokens=None,
|
153 |
+
encoder_layer=4,
|
154 |
+
encoder_hidden=256,
|
155 |
+
encoder_head=4,
|
156 |
+
conv_filter_size=1024,
|
157 |
+
conv_kernel_size=5,
|
158 |
+
encoder_dropout=0.1,
|
159 |
+
use_cln=False,
|
160 |
+
use_pe=True,
|
161 |
+
cfg=None,
|
162 |
+
):
|
163 |
+
super().__init__()
|
164 |
+
|
165 |
+
self.encoder_layer = (
|
166 |
+
encoder_layer if encoder_layer is not None else cfg.encoder_layer
|
167 |
+
)
|
168 |
+
self.encoder_hidden = (
|
169 |
+
encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
|
170 |
+
)
|
171 |
+
self.encoder_head = (
|
172 |
+
encoder_head if encoder_head is not None else cfg.encoder_head
|
173 |
+
)
|
174 |
+
self.conv_filter_size = (
|
175 |
+
conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
|
176 |
+
)
|
177 |
+
self.conv_kernel_size = (
|
178 |
+
conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
|
179 |
+
)
|
180 |
+
self.encoder_dropout = (
|
181 |
+
encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
|
182 |
+
)
|
183 |
+
self.use_pe = use_pe if use_pe is not None else cfg.use_pe
|
184 |
+
self.use_cln = use_cln if use_cln is not None else cfg.use_cln
|
185 |
+
|
186 |
+
if enc_emb_tokens != None:
|
187 |
+
self.use_enc_emb = True
|
188 |
+
self.enc_emb_tokens = enc_emb_tokens
|
189 |
+
else:
|
190 |
+
self.use_enc_emb = False
|
191 |
+
|
192 |
+
if self.use_pe:
|
193 |
+
self.position_emb = PositionalEncoding(
|
194 |
+
self.encoder_hidden, self.encoder_dropout
|
195 |
+
)
|
196 |
+
|
197 |
+
self.layers = nn.ModuleList([])
|
198 |
+
self.layers.extend(
|
199 |
+
[
|
200 |
+
TransformerEncoderLayer(
|
201 |
+
self.encoder_hidden,
|
202 |
+
self.encoder_head,
|
203 |
+
self.conv_filter_size,
|
204 |
+
self.conv_kernel_size,
|
205 |
+
self.encoder_dropout,
|
206 |
+
self.use_cln,
|
207 |
+
)
|
208 |
+
for i in range(self.encoder_layer)
|
209 |
+
]
|
210 |
+
)
|
211 |
+
|
212 |
+
if self.use_cln:
|
213 |
+
self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
214 |
+
else:
|
215 |
+
self.last_ln = nn.LayerNorm(self.encoder_hidden)
|
216 |
+
|
217 |
+
def forward(self, x, key_padding_mask, condition=None):
|
218 |
+
if len(x.shape) == 2 and self.use_enc_emb:
|
219 |
+
x = self.enc_emb_tokens(x)
|
220 |
+
if self.use_pe:
|
221 |
+
x = self.position_emb(x)
|
222 |
+
else:
|
223 |
+
if self.use_pe:
|
224 |
+
x = self.position_emb(x) # (B, T, d)
|
225 |
+
|
226 |
+
for layer in self.layers:
|
227 |
+
x = layer(x, key_padding_mask, condition)
|
228 |
+
|
229 |
+
if self.use_cln:
|
230 |
+
x = self.last_ln(x, condition)
|
231 |
+
else:
|
232 |
+
x = self.last_ln(x)
|
233 |
+
|
234 |
+
return x
|
modules/audio_tokenizer/vocos.py
ADDED
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import scipy
|
5 |
+
import torch
|
6 |
+
from torch import nn, view_as_real, view_as_complex
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
9 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
10 |
+
|
11 |
+
|
12 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
13 |
+
"""
|
14 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
x (Tensor): Input tensor.
|
18 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
22 |
+
"""
|
23 |
+
return torch.log(torch.clip(x, min=clip_val))
|
24 |
+
|
25 |
+
|
26 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
27 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
28 |
+
|
29 |
+
|
30 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
31 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
32 |
+
|
33 |
+
|
34 |
+
class STFT(nn.Module):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
n_fft: int,
|
38 |
+
hop_length: int,
|
39 |
+
win_length: int,
|
40 |
+
center=True,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.center = center
|
44 |
+
self.n_fft = n_fft
|
45 |
+
self.hop_length = hop_length
|
46 |
+
self.win_length = win_length
|
47 |
+
window = torch.hann_window(win_length)
|
48 |
+
self.register_buffer("window", window)
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
51 |
+
# x: (B, T * hop_length)
|
52 |
+
|
53 |
+
if not self.center:
|
54 |
+
pad = self.win_length - self.hop_length
|
55 |
+
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
56 |
+
|
57 |
+
stft_spec = torch.stft(
|
58 |
+
x,
|
59 |
+
self.n_fft,
|
60 |
+
hop_length=self.hop_length,
|
61 |
+
win_length=self.win_length,
|
62 |
+
window=self.window,
|
63 |
+
center=self.center,
|
64 |
+
return_complex=False,
|
65 |
+
) # (B, n_fft // 2 + 1, T, 2)
|
66 |
+
|
67 |
+
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
|
68 |
+
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
|
69 |
+
|
70 |
+
log_mag = torch.log(
|
71 |
+
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
72 |
+
) # (B, n_fft // 2 + 1, T)
|
73 |
+
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
74 |
+
|
75 |
+
return log_mag, phase
|
76 |
+
|
77 |
+
|
78 |
+
class ISTFT(nn.Module):
|
79 |
+
"""
|
80 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
81 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
82 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
83 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
84 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
n_fft (int): Size of Fourier transform.
|
88 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
89 |
+
win_length (int): The size of window frame and STFT filter.
|
90 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
if padding not in ["center", "same"]:
|
98 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
99 |
+
self.padding = padding
|
100 |
+
self.n_fft = n_fft
|
101 |
+
self.hop_length = hop_length
|
102 |
+
self.win_length = win_length
|
103 |
+
window = torch.hann_window(win_length)
|
104 |
+
self.register_buffer("window", window)
|
105 |
+
|
106 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
107 |
+
"""
|
108 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
112 |
+
N is the number of frequency bins, and T is the number of time frames.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
116 |
+
"""
|
117 |
+
if self.padding == "center":
|
118 |
+
# Fallback to pytorch native implementation
|
119 |
+
return torch.istft(
|
120 |
+
spec,
|
121 |
+
self.n_fft,
|
122 |
+
self.hop_length,
|
123 |
+
self.win_length,
|
124 |
+
self.window,
|
125 |
+
center=True,
|
126 |
+
)
|
127 |
+
elif self.padding == "same":
|
128 |
+
pad = (self.win_length - self.hop_length) // 2
|
129 |
+
else:
|
130 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
131 |
+
|
132 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
133 |
+
B, N, T = spec.shape
|
134 |
+
|
135 |
+
# Inverse FFT
|
136 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
137 |
+
ifft = ifft * self.window[None, :, None]
|
138 |
+
|
139 |
+
# Overlap and Add
|
140 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
141 |
+
y = torch.nn.functional.fold(
|
142 |
+
ifft,
|
143 |
+
output_size=(1, output_size),
|
144 |
+
kernel_size=(1, self.win_length),
|
145 |
+
stride=(1, self.hop_length),
|
146 |
+
)[:, 0, 0, pad:-pad]
|
147 |
+
|
148 |
+
# Window envelope
|
149 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
150 |
+
window_envelope = torch.nn.functional.fold(
|
151 |
+
window_sq,
|
152 |
+
output_size=(1, output_size),
|
153 |
+
kernel_size=(1, self.win_length),
|
154 |
+
stride=(1, self.hop_length),
|
155 |
+
).squeeze()[pad:-pad]
|
156 |
+
|
157 |
+
# Normalize
|
158 |
+
assert (window_envelope > 1e-11).all()
|
159 |
+
y = y / window_envelope
|
160 |
+
|
161 |
+
return y
|
162 |
+
|
163 |
+
|
164 |
+
class MDCT(nn.Module):
|
165 |
+
"""
|
166 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
frame_len (int): Length of the MDCT frame.
|
170 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
174 |
+
super().__init__()
|
175 |
+
if padding not in ["center", "same"]:
|
176 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
177 |
+
self.padding = padding
|
178 |
+
self.frame_len = frame_len
|
179 |
+
N = frame_len // 2
|
180 |
+
n0 = (N + 1) / 2
|
181 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
182 |
+
self.register_buffer("window", window)
|
183 |
+
|
184 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
185 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
186 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
187 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
188 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
189 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
190 |
+
|
191 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
192 |
+
"""
|
193 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
197 |
+
and T is the length of the audio.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
201 |
+
and N is the number of frequency bins.
|
202 |
+
"""
|
203 |
+
if self.padding == "center":
|
204 |
+
audio = torch.nn.functional.pad(
|
205 |
+
audio, (self.frame_len // 2, self.frame_len // 2)
|
206 |
+
)
|
207 |
+
elif self.padding == "same":
|
208 |
+
# hop_length is 1/2 frame_len
|
209 |
+
audio = torch.nn.functional.pad(
|
210 |
+
audio, (self.frame_len // 4, self.frame_len // 4)
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
214 |
+
|
215 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
216 |
+
N = self.frame_len // 2
|
217 |
+
x = x * self.window.expand(x.shape)
|
218 |
+
X = torch.fft.fft(
|
219 |
+
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
220 |
+
)[..., :N]
|
221 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
222 |
+
return torch.real(res) * np.sqrt(2)
|
223 |
+
|
224 |
+
|
225 |
+
class IMDCT(nn.Module):
|
226 |
+
"""
|
227 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
frame_len (int): Length of the MDCT frame.
|
231 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
235 |
+
super().__init__()
|
236 |
+
if padding not in ["center", "same"]:
|
237 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
238 |
+
self.padding = padding
|
239 |
+
self.frame_len = frame_len
|
240 |
+
N = frame_len // 2
|
241 |
+
n0 = (N + 1) / 2
|
242 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
243 |
+
self.register_buffer("window", window)
|
244 |
+
|
245 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
246 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
247 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
248 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
249 |
+
|
250 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
251 |
+
"""
|
252 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
256 |
+
L is the number of frames, and N is the number of frequency bins.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
260 |
+
"""
|
261 |
+
B, L, N = X.shape
|
262 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
263 |
+
Y[..., :N] = X
|
264 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
265 |
+
y = torch.fft.ifft(
|
266 |
+
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
267 |
+
)
|
268 |
+
y = (
|
269 |
+
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
270 |
+
* np.sqrt(N)
|
271 |
+
* np.sqrt(2)
|
272 |
+
)
|
273 |
+
result = y * self.window.expand(y.shape)
|
274 |
+
output_size = (1, (L + 1) * N)
|
275 |
+
audio = torch.nn.functional.fold(
|
276 |
+
result.transpose(1, 2),
|
277 |
+
output_size=output_size,
|
278 |
+
kernel_size=(1, self.frame_len),
|
279 |
+
stride=(1, self.frame_len // 2),
|
280 |
+
)[:, 0, 0, :]
|
281 |
+
|
282 |
+
if self.padding == "center":
|
283 |
+
pad = self.frame_len // 2
|
284 |
+
elif self.padding == "same":
|
285 |
+
pad = self.frame_len // 4
|
286 |
+
else:
|
287 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
288 |
+
|
289 |
+
audio = audio[:, pad:-pad]
|
290 |
+
return audio
|
291 |
+
|
292 |
+
|
293 |
+
class FourierHead(nn.Module):
|
294 |
+
"""Base class for inverse fourier modules."""
|
295 |
+
|
296 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
297 |
+
"""
|
298 |
+
Args:
|
299 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
300 |
+
L is the sequence length, and H denotes the model dimension.
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
304 |
+
"""
|
305 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
306 |
+
|
307 |
+
|
308 |
+
class ISTFTHead(FourierHead):
|
309 |
+
"""
|
310 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
dim (int): Hidden dimension of the model.
|
314 |
+
n_fft (int): Size of Fourier transform.
|
315 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
316 |
+
the resolution of the input features.
|
317 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
318 |
+
"""
|
319 |
+
|
320 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
321 |
+
super().__init__()
|
322 |
+
out_dim = n_fft + 2
|
323 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
324 |
+
self.istft = ISTFT(
|
325 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
326 |
+
)
|
327 |
+
|
328 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
329 |
+
"""
|
330 |
+
Forward pass of the ISTFTHead module.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
334 |
+
L is the sequence length, and H denotes the model dimension.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
338 |
+
"""
|
339 |
+
x = self.out(x).transpose(1, 2)
|
340 |
+
mag, p = x.chunk(2, dim=1)
|
341 |
+
mag = torch.exp(mag)
|
342 |
+
mag = torch.clip(
|
343 |
+
mag, max=1e2
|
344 |
+
) # safeguard to prevent excessively large magnitudes
|
345 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
346 |
+
x = torch.cos(p)
|
347 |
+
y = torch.sin(p)
|
348 |
+
# recalculating phase here does not produce anything new
|
349 |
+
# only costs time
|
350 |
+
# phase = torch.atan2(y, x)
|
351 |
+
# S = mag * torch.exp(phase * 1j)
|
352 |
+
# better directly produce the complex value
|
353 |
+
S = mag * (x + 1j * y)
|
354 |
+
audio = self.istft(S)
|
355 |
+
return audio
|
356 |
+
|
357 |
+
|
358 |
+
class IMDCTSymExpHead(FourierHead):
|
359 |
+
"""
|
360 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
361 |
+
|
362 |
+
Args:
|
363 |
+
dim (int): Hidden dimension of the model.
|
364 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
365 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
366 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
367 |
+
based on perceptual scaling. Defaults to None.
|
368 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
369 |
+
"""
|
370 |
+
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
dim: int,
|
374 |
+
mdct_frame_len: int,
|
375 |
+
padding: str = "same",
|
376 |
+
sample_rate: Optional[int] = None,
|
377 |
+
clip_audio: bool = False,
|
378 |
+
):
|
379 |
+
super().__init__()
|
380 |
+
out_dim = mdct_frame_len // 2
|
381 |
+
self.out = nn.Linear(dim, out_dim)
|
382 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
383 |
+
self.clip_audio = clip_audio
|
384 |
+
|
385 |
+
if sample_rate is not None:
|
386 |
+
# optionally init the last layer following mel-scale
|
387 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
388 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
389 |
+
f_pts = _mel_to_hz(m_pts)
|
390 |
+
scale = 1 - (f_pts / f_pts.max())
|
391 |
+
|
392 |
+
with torch.no_grad():
|
393 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
394 |
+
|
395 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
396 |
+
"""
|
397 |
+
Forward pass of the IMDCTSymExpHead module.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
401 |
+
L is the sequence length, and H denotes the model dimension.
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
405 |
+
"""
|
406 |
+
x = self.out(x)
|
407 |
+
x = symexp(x)
|
408 |
+
x = torch.clip(
|
409 |
+
x, min=-1e2, max=1e2
|
410 |
+
) # safeguard to prevent excessively large magnitudes
|
411 |
+
audio = self.imdct(x)
|
412 |
+
if self.clip_audio:
|
413 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
414 |
+
|
415 |
+
return audio
|
416 |
+
|
417 |
+
|
418 |
+
class IMDCTCosHead(FourierHead):
|
419 |
+
"""
|
420 |
+
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Hidden dimension of the model.
|
424 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
425 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
426 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
427 |
+
"""
|
428 |
+
|
429 |
+
def __init__(
|
430 |
+
self,
|
431 |
+
dim: int,
|
432 |
+
mdct_frame_len: int,
|
433 |
+
padding: str = "same",
|
434 |
+
clip_audio: bool = False,
|
435 |
+
):
|
436 |
+
super().__init__()
|
437 |
+
self.clip_audio = clip_audio
|
438 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
439 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
440 |
+
|
441 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
442 |
+
"""
|
443 |
+
Forward pass of the IMDCTCosHead module.
|
444 |
+
|
445 |
+
Args:
|
446 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
447 |
+
L is the sequence length, and H denotes the model dimension.
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
451 |
+
"""
|
452 |
+
x = self.out(x)
|
453 |
+
m, p = x.chunk(2, dim=2)
|
454 |
+
m = torch.exp(m).clip(
|
455 |
+
max=1e2
|
456 |
+
) # safeguard to prevent excessively large magnitudes
|
457 |
+
audio = self.imdct(m * torch.cos(p))
|
458 |
+
if self.clip_audio:
|
459 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
460 |
+
return audio
|
461 |
+
|
462 |
+
|
463 |
+
class ConvNeXtBlock(nn.Module):
|
464 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
465 |
+
|
466 |
+
Args:
|
467 |
+
dim (int): Number of input channels.
|
468 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
469 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
470 |
+
Defaults to None.
|
471 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
472 |
+
None means non-conditional LayerNorm. Defaults to None.
|
473 |
+
"""
|
474 |
+
|
475 |
+
def __init__(
|
476 |
+
self,
|
477 |
+
dim: int,
|
478 |
+
intermediate_dim: int,
|
479 |
+
layer_scale_init_value: float,
|
480 |
+
adanorm_num_embeddings: Optional[int] = None,
|
481 |
+
):
|
482 |
+
super().__init__()
|
483 |
+
self.dwconv = nn.Conv1d(
|
484 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
485 |
+
) # depthwise conv
|
486 |
+
self.adanorm = adanorm_num_embeddings is not None
|
487 |
+
if adanorm_num_embeddings:
|
488 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
489 |
+
else:
|
490 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
491 |
+
self.pwconv1 = nn.Linear(
|
492 |
+
dim, intermediate_dim
|
493 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
494 |
+
self.act = nn.GELU()
|
495 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
496 |
+
self.gamma = (
|
497 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
498 |
+
if layer_scale_init_value > 0
|
499 |
+
else None
|
500 |
+
)
|
501 |
+
|
502 |
+
def forward(
|
503 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
504 |
+
) -> torch.Tensor:
|
505 |
+
residual = x
|
506 |
+
x = self.dwconv(x)
|
507 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
508 |
+
if self.adanorm:
|
509 |
+
assert cond_embedding_id is not None
|
510 |
+
x = self.norm(x, cond_embedding_id)
|
511 |
+
else:
|
512 |
+
x = self.norm(x)
|
513 |
+
x = self.pwconv1(x)
|
514 |
+
x = self.act(x)
|
515 |
+
x = self.pwconv2(x)
|
516 |
+
if self.gamma is not None:
|
517 |
+
x = self.gamma * x
|
518 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
519 |
+
|
520 |
+
x = residual + x
|
521 |
+
return x
|
522 |
+
|
523 |
+
|
524 |
+
class AdaLayerNorm(nn.Module):
|
525 |
+
"""
|
526 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
527 |
+
|
528 |
+
Args:
|
529 |
+
num_embeddings (int): Number of embeddings.
|
530 |
+
embedding_dim (int): Dimension of the embeddings.
|
531 |
+
"""
|
532 |
+
|
533 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
534 |
+
super().__init__()
|
535 |
+
self.eps = eps
|
536 |
+
self.dim = embedding_dim
|
537 |
+
self.scale = nn.Embedding(
|
538 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
539 |
+
)
|
540 |
+
self.shift = nn.Embedding(
|
541 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
542 |
+
)
|
543 |
+
torch.nn.init.ones_(self.scale.weight)
|
544 |
+
torch.nn.init.zeros_(self.shift.weight)
|
545 |
+
|
546 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
547 |
+
scale = self.scale(cond_embedding_id)
|
548 |
+
shift = self.shift(cond_embedding_id)
|
549 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
550 |
+
x = x * scale + shift
|
551 |
+
return x
|
552 |
+
|
553 |
+
|
554 |
+
class ResBlock1(nn.Module):
|
555 |
+
"""
|
556 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
557 |
+
but without upsampling layers.
|
558 |
+
|
559 |
+
Args:
|
560 |
+
dim (int): Number of input channels.
|
561 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
562 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
563 |
+
Defaults to (1, 3, 5).
|
564 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
565 |
+
Defaults to 0.1.
|
566 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
567 |
+
Defaults to None.
|
568 |
+
"""
|
569 |
+
|
570 |
+
def __init__(
|
571 |
+
self,
|
572 |
+
dim: int,
|
573 |
+
kernel_size: int = 3,
|
574 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
575 |
+
lrelu_slope: float = 0.1,
|
576 |
+
layer_scale_init_value: Optional[float] = None,
|
577 |
+
):
|
578 |
+
super().__init__()
|
579 |
+
self.lrelu_slope = lrelu_slope
|
580 |
+
self.convs1 = nn.ModuleList(
|
581 |
+
[
|
582 |
+
weight_norm(
|
583 |
+
nn.Conv1d(
|
584 |
+
dim,
|
585 |
+
dim,
|
586 |
+
kernel_size,
|
587 |
+
1,
|
588 |
+
dilation=dilation[0],
|
589 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
590 |
+
)
|
591 |
+
),
|
592 |
+
weight_norm(
|
593 |
+
nn.Conv1d(
|
594 |
+
dim,
|
595 |
+
dim,
|
596 |
+
kernel_size,
|
597 |
+
1,
|
598 |
+
dilation=dilation[1],
|
599 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
600 |
+
)
|
601 |
+
),
|
602 |
+
weight_norm(
|
603 |
+
nn.Conv1d(
|
604 |
+
dim,
|
605 |
+
dim,
|
606 |
+
kernel_size,
|
607 |
+
1,
|
608 |
+
dilation=dilation[2],
|
609 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
610 |
+
)
|
611 |
+
),
|
612 |
+
]
|
613 |
+
)
|
614 |
+
|
615 |
+
self.convs2 = nn.ModuleList(
|
616 |
+
[
|
617 |
+
weight_norm(
|
618 |
+
nn.Conv1d(
|
619 |
+
dim,
|
620 |
+
dim,
|
621 |
+
kernel_size,
|
622 |
+
1,
|
623 |
+
dilation=1,
|
624 |
+
padding=self.get_padding(kernel_size, 1),
|
625 |
+
)
|
626 |
+
),
|
627 |
+
weight_norm(
|
628 |
+
nn.Conv1d(
|
629 |
+
dim,
|
630 |
+
dim,
|
631 |
+
kernel_size,
|
632 |
+
1,
|
633 |
+
dilation=1,
|
634 |
+
padding=self.get_padding(kernel_size, 1),
|
635 |
+
)
|
636 |
+
),
|
637 |
+
weight_norm(
|
638 |
+
nn.Conv1d(
|
639 |
+
dim,
|
640 |
+
dim,
|
641 |
+
kernel_size,
|
642 |
+
1,
|
643 |
+
dilation=1,
|
644 |
+
padding=self.get_padding(kernel_size, 1),
|
645 |
+
)
|
646 |
+
),
|
647 |
+
]
|
648 |
+
)
|
649 |
+
|
650 |
+
self.gamma = nn.ParameterList(
|
651 |
+
[
|
652 |
+
(
|
653 |
+
nn.Parameter(
|
654 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
655 |
+
)
|
656 |
+
if layer_scale_init_value is not None
|
657 |
+
else None
|
658 |
+
),
|
659 |
+
(
|
660 |
+
nn.Parameter(
|
661 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
662 |
+
)
|
663 |
+
if layer_scale_init_value is not None
|
664 |
+
else None
|
665 |
+
),
|
666 |
+
(
|
667 |
+
nn.Parameter(
|
668 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
669 |
+
)
|
670 |
+
if layer_scale_init_value is not None
|
671 |
+
else None
|
672 |
+
),
|
673 |
+
]
|
674 |
+
)
|
675 |
+
|
676 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
677 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
678 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
679 |
+
xt = c1(xt)
|
680 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
681 |
+
xt = c2(xt)
|
682 |
+
if gamma is not None:
|
683 |
+
xt = gamma * xt
|
684 |
+
x = xt + x
|
685 |
+
return x
|
686 |
+
|
687 |
+
def remove_weight_norm(self):
|
688 |
+
for l in self.convs1:
|
689 |
+
remove_weight_norm(l)
|
690 |
+
for l in self.convs2:
|
691 |
+
remove_weight_norm(l)
|
692 |
+
|
693 |
+
@staticmethod
|
694 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
695 |
+
return int((kernel_size * dilation - dilation) / 2)
|
696 |
+
|
697 |
+
|
698 |
+
class Backbone(nn.Module):
|
699 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
700 |
+
|
701 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
702 |
+
"""
|
703 |
+
Args:
|
704 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
705 |
+
C denotes output features, and L is the sequence length.
|
706 |
+
|
707 |
+
Returns:
|
708 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
709 |
+
and H denotes the model dimension.
|
710 |
+
"""
|
711 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
712 |
+
|
713 |
+
|
714 |
+
class VocosBackbone(Backbone):
|
715 |
+
"""
|
716 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
717 |
+
|
718 |
+
Args:
|
719 |
+
input_channels (int): Number of input features channels.
|
720 |
+
dim (int): Hidden dimension of the model.
|
721 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
722 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
723 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
724 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
725 |
+
None means non-conditional model. Defaults to None.
|
726 |
+
"""
|
727 |
+
|
728 |
+
def __init__(
|
729 |
+
self,
|
730 |
+
input_channels: int,
|
731 |
+
dim: int,
|
732 |
+
intermediate_dim: int,
|
733 |
+
num_layers: int,
|
734 |
+
layer_scale_init_value: Optional[float] = None,
|
735 |
+
adanorm_num_embeddings: Optional[int] = None,
|
736 |
+
):
|
737 |
+
super().__init__()
|
738 |
+
self.input_channels = input_channels
|
739 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
740 |
+
self.adanorm = adanorm_num_embeddings is not None
|
741 |
+
if adanorm_num_embeddings:
|
742 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
743 |
+
else:
|
744 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
745 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
746 |
+
self.convnext = nn.ModuleList(
|
747 |
+
[
|
748 |
+
ConvNeXtBlock(
|
749 |
+
dim=dim,
|
750 |
+
intermediate_dim=intermediate_dim,
|
751 |
+
layer_scale_init_value=layer_scale_init_value,
|
752 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
753 |
+
)
|
754 |
+
for _ in range(num_layers)
|
755 |
+
]
|
756 |
+
)
|
757 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
758 |
+
self.apply(self._init_weights)
|
759 |
+
|
760 |
+
def _init_weights(self, m):
|
761 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
762 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
763 |
+
nn.init.constant_(m.bias, 0)
|
764 |
+
|
765 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
766 |
+
bandwidth_id = kwargs.get("bandwidth_id", None)
|
767 |
+
x = self.embed(x)
|
768 |
+
if self.adanorm:
|
769 |
+
assert bandwidth_id is not None
|
770 |
+
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
771 |
+
else:
|
772 |
+
x = self.norm(x.transpose(1, 2))
|
773 |
+
x = x.transpose(1, 2)
|
774 |
+
for conv_block in self.convnext:
|
775 |
+
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
776 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
777 |
+
return x
|
778 |
+
|
779 |
+
|
780 |
+
class VocosResNetBackbone(Backbone):
|
781 |
+
"""
|
782 |
+
Vocos backbone module built with ResBlocks.
|
783 |
+
|
784 |
+
Args:
|
785 |
+
input_channels (int): Number of input features channels.
|
786 |
+
dim (int): Hidden dimension of the model.
|
787 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
788 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
789 |
+
"""
|
790 |
+
|
791 |
+
def __init__(
|
792 |
+
self,
|
793 |
+
input_channels,
|
794 |
+
dim,
|
795 |
+
num_blocks,
|
796 |
+
layer_scale_init_value=None,
|
797 |
+
):
|
798 |
+
super().__init__()
|
799 |
+
self.input_channels = input_channels
|
800 |
+
self.embed = weight_norm(
|
801 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
802 |
+
)
|
803 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
804 |
+
self.resnet = nn.Sequential(
|
805 |
+
*[
|
806 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
807 |
+
for _ in range(num_blocks)
|
808 |
+
]
|
809 |
+
)
|
810 |
+
|
811 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
812 |
+
x = self.embed(x)
|
813 |
+
x = self.resnet(x)
|
814 |
+
x = x.transpose(1, 2)
|
815 |
+
return x
|
816 |
+
|
817 |
+
|
818 |
+
class Vocos(nn.Module):
|
819 |
+
def __init__(
|
820 |
+
self,
|
821 |
+
input_channels: int = 256,
|
822 |
+
dim: int = 384,
|
823 |
+
intermediate_dim: int = 1152,
|
824 |
+
num_layers: int = 8,
|
825 |
+
adanorm_num_embeddings: int = 4,
|
826 |
+
n_fft: int = 800,
|
827 |
+
hop_size: int = 200,
|
828 |
+
padding: str = "same",
|
829 |
+
):
|
830 |
+
super().__init__()
|
831 |
+
|
832 |
+
self.backbone = VocosBackbone(
|
833 |
+
input_channels=input_channels,
|
834 |
+
dim=dim,
|
835 |
+
intermediate_dim=intermediate_dim,
|
836 |
+
num_layers=num_layers,
|
837 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
838 |
+
)
|
839 |
+
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
|
840 |
+
|
841 |
+
def forward(self, x):
|
842 |
+
x = self.backbone(x)
|
843 |
+
x = self.head(x)
|
844 |
+
|
845 |
+
return x[:, None, :]
|
modules/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
from abc import abstractmethod
|
3 |
+
import sentencepiece as spm
|
4 |
+
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
5 |
+
from typing import Any, Union
|
6 |
+
import numpy as np
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False):
|
10 |
+
"""Encode text into sentence pieces. Only supports py3."""
|
11 |
+
|
12 |
+
if not sample:
|
13 |
+
pieces = sp_model.EncodeAsPieces(text)
|
14 |
+
else:
|
15 |
+
pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
|
16 |
+
|
17 |
+
return pieces
|
18 |
+
|
19 |
+
|
20 |
+
class AbstractTokenizer(ABC):
|
21 |
+
"""Abstract class for tokenizer."""
|
22 |
+
|
23 |
+
def __init__(self, name):
|
24 |
+
self.name = name
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
@property
|
28 |
+
@abstractmethod
|
29 |
+
def vocab_size(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
@property
|
33 |
+
@abstractmethod
|
34 |
+
def vocab(self):
|
35 |
+
"""Dictionary from vocab text token to id token."""
|
36 |
+
pass
|
37 |
+
|
38 |
+
@property
|
39 |
+
@abstractmethod
|
40 |
+
def inv_vocab(self):
|
41 |
+
"""Dictionary from vocab id token to text token."""
|
42 |
+
pass
|
43 |
+
|
44 |
+
@abstractmethod
|
45 |
+
def tokenize(self, text):
|
46 |
+
pass
|
47 |
+
|
48 |
+
def detokenize(self, token_ids):
|
49 |
+
raise NotImplementedError('detokenizer is not implemented for {} '
|
50 |
+
'tokenizer'.format(self.name))
|
51 |
+
|
52 |
+
@property
|
53 |
+
def cls(self):
|
54 |
+
raise NotImplementedError('CLS is not provided for {} '
|
55 |
+
'tokenizer'.format(self.name))
|
56 |
+
|
57 |
+
@property
|
58 |
+
def sep(self):
|
59 |
+
raise NotImplementedError('SEP is not provided for {} '
|
60 |
+
'tokenizer'.format(self.name))
|
61 |
+
|
62 |
+
@property
|
63 |
+
def pad(self):
|
64 |
+
raise NotImplementedError('PAD is not provided for {} '
|
65 |
+
'tokenizer'.format(self.name))
|
66 |
+
|
67 |
+
@property
|
68 |
+
def eod(self):
|
69 |
+
raise NotImplementedError('EOD is not provided for {} '
|
70 |
+
'tokenizer'.format(self.name))
|
71 |
+
|
72 |
+
@property
|
73 |
+
def mask(self):
|
74 |
+
raise NotImplementedError('MASK is not provided for {} '
|
75 |
+
'tokenizer'.format(self.name))
|
76 |
+
|
77 |
+
|
78 |
+
class SPieceTokenizer(AbstractTokenizer):
|
79 |
+
def __init__(self, spm_file: str):
|
80 |
+
super().__init__('Sentence Piece')
|
81 |
+
self.sp_model = spm.SentencePieceProcessor()
|
82 |
+
self.sp_model.Load(spm_file)
|
83 |
+
self.eod_id = self.get_token_id('</s>')
|
84 |
+
|
85 |
+
self.special_ids = set([
|
86 |
+
self.sp_model.pad_id(),
|
87 |
+
self.sp_model.eos_id(),
|
88 |
+
self.sp_model.bos_id(),
|
89 |
+
self.sp_model.unk_id(),
|
90 |
+
self.eod_id,
|
91 |
+
])
|
92 |
+
|
93 |
+
# initialize index_2_bytes
|
94 |
+
self._initialize_index_2_bytes()
|
95 |
+
|
96 |
+
def encode_pieces(self, text: str, sample=False):
|
97 |
+
if not sample:
|
98 |
+
pieces = self.sp_model.EncodeAsPieces(text)
|
99 |
+
else:
|
100 |
+
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
|
101 |
+
return pieces
|
102 |
+
|
103 |
+
def _initialize_index_2_bytes(self):
|
104 |
+
proto = sp_pb2_model.ModelProto()
|
105 |
+
proto.ParseFromString(self.sp_model.serialized_model_proto())
|
106 |
+
self.index_2_numbytes = [0] * len(proto.pieces)
|
107 |
+
for i, p in enumerate(proto.pieces):
|
108 |
+
clean_piece = p.piece.replace('▁', '')
|
109 |
+
self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))
|
110 |
+
|
111 |
+
def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
|
112 |
+
proto = sp_pb2_model.ModelProto()
|
113 |
+
proto.ParseFromString(self.sp_model.serialized_model_proto())
|
114 |
+
if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
|
115 |
+
proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
|
116 |
+
self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
|
117 |
+
print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)
|
118 |
+
|
119 |
+
def add_special_id(self, token_id):
|
120 |
+
self.special_ids.add(token_id)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def has_dummy_prefix(self):
|
124 |
+
pieces = self.sp_model.EncodeAsPieces("hello")
|
125 |
+
return pieces[0].startswith('▁')
|
126 |
+
|
127 |
+
@property
|
128 |
+
def vocab_size(self):
|
129 |
+
return self.sp_model.GetPieceSize()
|
130 |
+
|
131 |
+
@property
|
132 |
+
def vocab(self):
|
133 |
+
"""Dictionary from vocab text token to id token."""
|
134 |
+
return self.sp_model
|
135 |
+
|
136 |
+
def get_array_bytes(self, array):
|
137 |
+
return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)
|
138 |
+
|
139 |
+
def tokenize(self, text):
|
140 |
+
tokens = encode_pieces(self.sp_model, text)
|
141 |
+
return self.convert_tokens_to_ids(tokens)
|
142 |
+
|
143 |
+
def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:
|
144 |
+
tokens = self.encode_pieces(text)
|
145 |
+
t = self.convert_tokens_to_ids(tokens)
|
146 |
+
if bos:
|
147 |
+
t.insert(0, self.bos_id)
|
148 |
+
if eos:
|
149 |
+
t.append(self.eos_id)
|
150 |
+
return t
|
151 |
+
|
152 |
+
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
153 |
+
if isinstance(tokens, str):
|
154 |
+
return self.sp_model.PieceToId(tokens)
|
155 |
+
return [self.sp_model.PieceToId(token) for token in tokens]
|
156 |
+
|
157 |
+
def detokenize(self, token_ids):
|
158 |
+
if isinstance(token_ids, list):
|
159 |
+
pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
|
160 |
+
else:
|
161 |
+
pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
|
162 |
+
return pieces
|
163 |
+
|
164 |
+
def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
|
165 |
+
assert not skip_special_tokens, "skip_special_tokens is not supported"
|
166 |
+
if isinstance(token_ids, (int, np.integer)):
|
167 |
+
return self.detokenize([int(token_ids)])[0]
|
168 |
+
return ''.join(self.detokenize(token_ids))
|
169 |
+
|
170 |
+
def get_token_id(self, token):
|
171 |
+
return self.sp_model.PieceToId(token)
|
172 |
+
|
173 |
+
def inv_vocab(self):
|
174 |
+
# TODO: to be implemented
|
175 |
+
return {}
|
176 |
+
|
177 |
+
def decode_pieces(self, pieces):
|
178 |
+
return self.sp_model.DecodePieces(pieces)
|
179 |
+
|
180 |
+
@property
|
181 |
+
def eod(self):
|
182 |
+
return self.eod_id
|
183 |
+
|
184 |
+
@property
|
185 |
+
def pad_id(self):
|
186 |
+
return self.sp_model.pad_id()
|
187 |
+
|
188 |
+
@property
|
189 |
+
def eos_id(self):
|
190 |
+
return self.sp_model.eos_id()
|
191 |
+
|
192 |
+
@property
|
193 |
+
def bos_id(self):
|
194 |
+
return self.sp_model.bos_id()
|
195 |
+
|
196 |
+
@property
|
197 |
+
def unk_id(self):
|
198 |
+
return self.sp_model.unk_id()
|
199 |
+
|
200 |
+
@property
|
201 |
+
def pad_token_id(self):
|
202 |
+
return self.pad_id
|
203 |
+
|
204 |
+
@property
|
205 |
+
def eos_token_id(self):
|
206 |
+
return self.eos_id
|
207 |
+
|
208 |
+
|
209 |
+
@dataclass
|
210 |
+
class ExtraTokens:
|
211 |
+
msg_end: int
|
212 |
+
user_msg_start: int
|
213 |
+
assistant_msg_start: int
|
214 |
+
name_end: int
|
215 |
+
media_begin: int
|
216 |
+
media_content: int
|
217 |
+
media_end: int
|
218 |
+
pad: int
|
219 |
+
|
220 |
+
|
221 |
+
def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
|
222 |
+
if isinstance(tokenizer, SPieceTokenizer):
|
223 |
+
map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)
|
224 |
+
else:
|
225 |
+
raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
|
226 |
+
|
227 |
+
return ExtraTokens(
|
228 |
+
msg_end=map_fn('[extra_id_0]'),
|
229 |
+
user_msg_start=map_fn('[extra_id_1]'),
|
230 |
+
assistant_msg_start=map_fn('[extra_id_2]'),
|
231 |
+
name_end=map_fn('[extra_id_12]'),
|
232 |
+
media_begin=map_fn('[extra_id_13]'),
|
233 |
+
media_content=map_fn('[extra_id_14]'),
|
234 |
+
media_end=map_fn('[extra_id_15]'),
|
235 |
+
pad=tokenizer.pad_id
|
236 |
+
)
|
237 |
+
|
238 |
+
def get_tokenizer_and_extra_tokens():
|
239 |
+
sp_model_path = "resources/tokenizer/160k.model"
|
240 |
+
tokenizer = SPieceTokenizer(sp_model_path)
|
241 |
+
tokenizer.set_add_dummy_prefix(False)
|
242 |
+
extra_tokens = instantiate_extra_tokens(tokenizer)
|
243 |
+
return tokenizer, extra_tokens
|
readme.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MoonCast: High-Quality Zero-Shot Podcast Generation
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
Demo page: [demo](https://mooncastdemo.github.io)
|
5 |
+
|
6 |
+
Paper: [paper](https://arxiv.org/abs/2503.14345)
|
7 |
+
|
8 |
+
We open-source this system to advance the field of human-like speech synthesis. Our goal is to create more natural and expressive synthetic voices that bridge the gap between machines and humans. We hope this project will inspire researchers and developers to explore new possibilities in voice technology. We warmly welcome contributions from anyone interested in this project. Whether through code, documentation, feedback, or sharing your insights, every input helps make this project better.
|
9 |
+
|
10 |
+
### Abstract
|
11 |
+
Recent advances in text-to-speech synthesis have achieved notable success in generating high-quality short utterances for individual speakers. However, these systems still face challenges when extending their capabilities to long, multi-speaker, and spontaneous dialogues, typical of real-world scenarios such as podcasts. These limitations arise from two primary challenges: 1) long speech: podcasts typically span several minutes, exceeding the upper limit of most existing work; 2) spontaneity: podcasts are marked by their spontaneous, oral nature, which sharply contrasts with formal, written contexts; existing works often fall short in capturing this spontaneity. In this paper, we propose MoonCast, a solution for high-quality zero-shot podcast generation, aiming to synthesize natural podcast-style speech from text-only sources (e.g., stories, technical reports, news in TXT, PDF, or Web URL formats) using the voices of unseen speakers. To generate long audio, we adopt a long-context language model-based audio modeling approach utilizing large-scale long-context speech data. To enhance spontaneity, we utilize a podcast generation module to generate scripts with spontaneous details, which have been empirically shown to be as crucial as the text-to-speech modeling itself. Experiments demonstrate that MoonCast outperforms baselines, with particularly notable improvements in spontaneity and coherence.
|
12 |
+
|
13 |
+
## Environment Setup
|
14 |
+
- Create conda environment.
|
15 |
+
|
16 |
+
``` sh
|
17 |
+
conda create -n mooncast -y python=3.10
|
18 |
+
conda activate mooncast
|
19 |
+
pip install -r requirements.txt
|
20 |
+
pip install flash-attn --no-build-isolation
|
21 |
+
pip install huggingface_hub
|
22 |
+
pip install gradio==5.22.0
|
23 |
+
```
|
24 |
+
|
25 |
+
- Download the pretrained weights.
|
26 |
+
``` sh
|
27 |
+
python download_pretrain.py
|
28 |
+
```
|
29 |
+
|
30 |
+
## Example Usage
|
31 |
+
|
32 |
+
The audio prompts used in this project are sourced from publicly available podcast segments and are intended solely for demonstration purposes. Redistribution of these audio files, whether in their original form or as generated audio, is strictly prohibited. If you have any concerns or questions regarding the use of these audio files, please contact us at [email protected]
|
33 |
+
|
34 |
+
```sh
|
35 |
+
CUDA_VISIBLE_DEVICIES=0 python inference.py
|
36 |
+
```
|
37 |
+
|
38 |
+
## Disclaimer
|
39 |
+
This project is intended for **research purposes only**. We strongly encourage users to **use this project and its generated audio responsibly**. **We are not responsible for any misuse or abuse of this project**. By using this project, you agree to comply with all applicable laws and ethical guidelines.
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.1
|
2 |
+
torchaudio==2.3.1
|
3 |
+
sentencepiece==0.2.0
|
4 |
+
protobuf
|
5 |
+
numpy
|
6 |
+
|
7 |
+
librosa==0.9.1
|
8 |
+
pyyaml
|
9 |
+
transformers
|
10 |
+
safetensors
|
11 |
+
einops
|
12 |
+
scipy
|
13 |
+
timm==1.0.7
|
14 |
+
torchdyn
|
15 |
+
librosa
|
16 |
+
accelerate==0.26.0
|
17 |
+
ninja
|
18 |
+
cryptography
|
test/test_audio_detokenizer.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
sys.path.append('.')
|
4 |
+
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
|
5 |
+
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize
|
6 |
+
import torchaudio
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
if __name__ == '__main__':
|
10 |
+
audio_tokenizer = get_audio_tokenizer()
|
11 |
+
audio_detokenizer = get_audio_detokenizer()
|
12 |
+
|
13 |
+
input_wav_16k, _ = librosa.load("en_prompt0.wav", sr=16000)
|
14 |
+
input_wav_24k, _ = librosa.load("en_prompt0.wav", sr=24000)
|
15 |
+
|
16 |
+
prompt_sec = 1
|
17 |
+
prompt_wav_16k = input_wav_16k[:16000*prompt_sec]
|
18 |
+
prompt_wav_24k = input_wav_24k[:24000*prompt_sec]
|
19 |
+
input_wav_16k = input_wav_16k[16000*prompt_sec:]
|
20 |
+
input_wav_24k = input_wav_24k[24000*prompt_sec:]
|
21 |
+
|
22 |
+
prompt_wav_24k = torch.tensor(prompt_wav_24k)[None, :].cuda()
|
23 |
+
prompt_wav_16k = torch.tensor(prompt_wav_16k)[None, :].cuda()
|
24 |
+
input_wav_24k = torch.tensor(input_wav_24k)[None, :].cuda()
|
25 |
+
input_wav_16k = torch.tensor(input_wav_16k)[None, :].cuda()
|
26 |
+
|
27 |
+
semantic_token = audio_tokenizer.tokenize(input_wav_16k)
|
28 |
+
prompt_semantic_token = audio_tokenizer.tokenize(prompt_wav_16k)
|
29 |
+
|
30 |
+
recon_wav = detokenize(audio_detokenizer, semantic_token, prompt_wav_24k, prompt_semantic_token)
|
31 |
+
print(recon_wav.shape)
|
32 |
+
torchaudio.save("test/tmp_recon_en_prompt0.wav", recon_wav.cpu(), 24000)
|
33 |
+
|
34 |
+
print("All tests passed!")
|
test/test_audio_tokenizer.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('.')
|
3 |
+
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
|
4 |
+
import torch
|
5 |
+
|
6 |
+
if __name__ == '__main__':
|
7 |
+
audio_tokenizer = get_audio_tokenizer()
|
8 |
+
|
9 |
+
input_wav = torch.zeros(1, 8000)
|
10 |
+
semantic_token = audio_tokenizer.tokenize(input_wav)
|
11 |
+
semantic_token = semantic_token.cpu().numpy().tolist()
|
12 |
+
assert semantic_token == [[ 765, 3512, 7469, 7469, 7028, 2567, 6008, 7469, 6217, 2567, 7649, 7469,
|
13 |
+
3292, 2567, 7649, 7469, 3292, 2567, 948, 7469, 3292, 2567, 948, 7469]]
|
14 |
+
|
15 |
+
print("All tests passed!")
|
test/test_tokenizer.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('.')
|
3 |
+
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == '__main__':
|
8 |
+
tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()
|
9 |
+
|
10 |
+
assert tokenizer.encode("user") == [1495]
|
11 |
+
assert tokenizer.decode([1495]) == "user"
|
12 |
+
|
13 |
+
assert tokenizer.encode("0") == [501]
|
14 |
+
assert tokenizer.decode([501]) == "0"
|
15 |
+
|
16 |
+
assert tokenizer.encode("1") == [503]
|
17 |
+
assert tokenizer.decode([503]) == "1"
|
18 |
+
|
19 |
+
assert tokenizer.encode("assistant") == [110866]
|
20 |
+
assert tokenizer.decode([110866]) == "assistant"
|
21 |
+
|
22 |
+
assert tokenizer.encode("audio") == [26229]
|
23 |
+
assert tokenizer.decode([26229]) == "audio"
|
24 |
+
|
25 |
+
|
26 |
+
assert extra_tokens.msg_end == 260
|
27 |
+
assert extra_tokens.user_msg_start == 261
|
28 |
+
assert extra_tokens.assistant_msg_start == 262
|
29 |
+
assert extra_tokens.name_end == 272
|
30 |
+
assert extra_tokens.media_begin == 273
|
31 |
+
assert extra_tokens.media_content == 274
|
32 |
+
assert extra_tokens.media_end == 275
|
33 |
+
|
34 |
+
assert [tokenizer.convert_tokens_to_ids(i) for i in ['<0x0A>', '</s>', '[extra_id_0]']] == [14, 1, 260]
|
35 |
+
|
36 |
+
print("All tests passed!")
|
37 |
+
|
zh_prompt0.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f4a334352093e0aafb2217b53f9a027e7619c74b3823f70e752aa9f51aebc597
|
3 |
+
size 240044
|
zh_prompt1.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21a376ffa8591b16ad3c2403d6d9f9a5053b799039770c5d49ceaa0c92d6eafe
|
3 |
+
size 228964
|