Commit 9d9ac6c ("init"), committed by Opus
Parent: 597284f

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
- .gitignore +180 -0
- README.md +5 -5
- app.py +189 -0
- datasets/CoTMovieDubbing/README.md +20 -0
- datasets/CoTMovieDubbing/filelist/cot_spk_for_speech_gen.lst +0 -0
- datasets/CoTMovieDubbing/filelist/mmlm_test.jsonl +0 -0
- datasets/CoTMovieDubbing/filelist/mmlm_train.jsonl +0 -0
- datasets/Grid/README.md +1 -0
- datasets/V2C/README.md +1 -0
- datasets/V2C/V2C_Setting2.txt +0 -0
- datasets/V2C/V2C_Setting3.txt +0 -0
- requirements.txt +237 -0
- ruff.toml +11 -0
- src/internvl/eval.py +337 -0
- src/moviedubber/configs/basemodel.yaml +9 -0
- src/moviedubber/eval.py +245 -0
- src/moviedubber/infer/basic.toml +4 -0
- src/moviedubber/infer/utils_infer.py +399 -0
- src/moviedubber/infer/video_preprocess.py +315 -0
- src/moviedubber/infer_with_mmlm_result.py +339 -0
- src/moviedubber/model/__init__.py +5 -0
- src/moviedubber/model/cfm.py +209 -0
- src/moviedubber/model/dit.py +297 -0
- src/moviedubber/model/modules.py +467 -0
- src/moviedubber/model/utils.py +128 -0
- src/third_party/BigVGAN/.gitignore +146 -0
- src/third_party/BigVGAN/LICENSE +21 -0
- src/third_party/BigVGAN/README.md +266 -0
- src/third_party/BigVGAN/activations.py +126 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/activation1d.py +77 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/load.py +86 -0
- src/third_party/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- src/third_party/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- src/third_party/BigVGAN/alias_free_activation/torch/act.py +30 -0
- src/third_party/BigVGAN/alias_free_activation/torch/filter.py +101 -0
- src/third_party/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- src/third_party/BigVGAN/bigvgan.py +493 -0
- src/third_party/BigVGAN/configs/bigvgan_22khz_80band.json +45 -0
- src/third_party/BigVGAN/configs/bigvgan_24khz_100band.json +45 -0
- src/third_party/BigVGAN/configs/bigvgan_base_22khz_80band.json +45 -0
- src/third_party/BigVGAN/configs/bigvgan_base_24khz_100band.json +45 -0
- src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json +61 -0
- src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json +61 -0
- src/third_party/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json +61 -0
- src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json +61 -0
- src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json +61 -0
.gitignore
ADDED
@@ -0,0 +1,180 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
data
# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.idea/

.DS_Store
data_process/
internvl_chat/work_dirs/
internvl_chat/unittest/
internvl_chat/data/
Husky2/*
data_process/
*distillation*

batchscript-*
results/

# *txt
*csv
*mp4
temp/
src/moviedubber/infer/basic_test.toml
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Deepdubber V1
+emoji: 🌍
+colorFrom: red
+colorTo: green
 sdk: gradio
 sdk_version: 5.22.0
 app_file: app.py
 pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,189 @@
import os
import tempfile

import gradio as gr
import librosa
import soundfile
import tomli
import torch
import torch.nn.functional as F
import torchaudio
from moviepy import VideoFileClip
from pydub import AudioSegment
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from src.moviedubber.infer.utils_infer import (
    cfg_strength,
    chunk_text,
    nfe_step,
    sway_sampling_coef,
)
from src.moviedubber.infer.video_preprocess import VideoFeatureExtractor
from src.moviedubber.infer_with_mmlm_result import concat_movie_with_audio, get_spk_emb, load_models
from src.moviedubber.model.utils import convert_char_to_pinyin


def load_asr_model(model_id="openai/whisper-large-v3-turbo"):
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe


device = "cpu"
config = tomli.load(open("src/moviedubber/infer/basic.toml", "rb"))


ema_model, vocoder, ort_session = load_models(config, device=device)
asr_pipe = load_asr_model()

videofeature_extractor = VideoFeatureExtractor(device=device)


def deepdubber(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
    print(f"Starting deepdubber with video_path: {video_path} and subtitle_text: {subtitle_text}")
    gen_clip = videofeature_extractor.extract_features(video_path)
    gen_text = subtitle_text

    clip = VideoFileClip(video_path)
    gen_audio_len = int(clip.duration * 24000 // 256)

    gen_clip = gen_clip.unsqueeze(0).to(device=device, dtype=torch.float32).transpose(1, 2)
    gen_clip = F.interpolate(gen_clip, size=(gen_audio_len,), mode="linear", align_corners=False).transpose(1, 2)

    ref_audio_len = None
    if audio_path is not None:
        print("reference audio is not None, dubbing with reference audio")

        if audio_path.endswith(".mp3"):
            audio = AudioSegment.from_mp3(audio_path)

            wav_file = audio_path.replace(".mp3", ".wav")
            audio.export(wav_file, format="wav")
        else:
            wav_file = audio_path

        ref_text = asr_pipe(librosa.load(wav_file, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
        ref_text = ref_text.replace("\n", " ").replace("\r", " ")
        print(f"Reference text: {ref_text}")

        spk_emb = get_spk_emb(wav_file, ort_session)
        spk_emb = torch.tensor(spk_emb).to(device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        audio_data, sr = torchaudio.load(wav_file)
        resampler = torchaudio.transforms.Resample(sr, 24000)
        if sr != 24000:
            audio_data = resampler(audio_data)

        if audio_data.shape[0] > 1:
            audio_data = torch.mean(audio_data, dim=0, keepdim=True)

        audio_data = audio_data.to(device)

        ref_audio_len = int(audio_data.shape[-1] // 256)
        ref_clip = torch.zeros((1, ref_audio_len, 768)).to(device=device)

        gen_clip = torch.cat((gen_clip, ref_clip), dim=1)

        gen_audio_len = ref_audio_len + gen_audio_len

        gen_text = ref_text + " " + gen_text

    else:
        spk_emb = torch.zeros((1, 1, 192)).to(device=device)
        audio_data = torch.zeros((1, gen_audio_len, 100)).to(device=device)

    gen_text_batches = chunk_text(gen_text, max_chars=1024)
    final_text_list = convert_char_to_pinyin(gen_text_batches)

    with torch.inference_mode():
        generated, _ = ema_model.sample(
            cond=audio_data,
            text=final_text_list,
            clip=gen_clip,
            spk_emb=spk_emb,
            duration=gen_audio_len,
            steps=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            no_ref_audio=False,
        )

        generated = generated.to(torch.float32)

        if ref_audio_len is not None:
            generated = generated[:, ref_audio_len:, :]

        generated_mel_spec = generated.permute(0, 2, 1)
        generated_wave = vocoder(generated_mel_spec)

        generated_wave = generated_wave.squeeze().cpu().numpy()

    # using a temporary wav file to save the generated audio
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_wav_file:
        temp_wav_path = temp_wav_file.name
        soundfile.write(temp_wav_path, generated_wave, samplerate=24000)

    concated_video = concat_movie_with_audio(temp_wav_path, video_path, ".")

    # Ensure the temporary file is deleted after use
    os.remove(temp_wav_path)

    print(f"Deepdubber completed successfully, output path: {concated_video}")
    return concated_video


def process_video_dubbing(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
    try:
        print(f"Processing video: {video_path}")
        if not os.path.exists(video_path):
            raise ValueError("Video file does not exist")

        if not subtitle_text.strip():
            raise ValueError("Subtitle text cannot be empty")

        output_path = deepdubber(video_path, subtitle_text, audio_path)

        return output_path

    except Exception as e:
        print(f"Error in process_video_dubbing: {e}")

        return None


def create_ui():
    with gr.Blocks(title="DeepDubber-V1") as app:
        gr.Markdown("# DeepDubber-V1\nUpload your video file and enter the text you want to dub")

        with gr.Row():
            video_input = gr.Video(label="Upload video")
            audio_input = gr.Audio(label="Upload audio", type="filepath")
            subtitle_input = gr.Textbox(label="Enter the text", placeholder="Enter the text to be dubbed...", lines=5)

        process_btn = gr.Button("Start Dubbing")

        output_video = gr.Video(label="Dubbed Video")

        process_btn.click(
            fn=process_video_dubbing, inputs=[video_input, subtitle_input, audio_input], outputs=output_video
        )

    return app


if __name__ == "__main__":
    app = create_ui()
    app.launch()
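For quick experiments without the web UI, the same pipeline can also be driven from Python. The snippet below is a minimal sketch, not part of this commit: the file paths are placeholders, and importing `app` loads the dubbing models at import time, exactly as it does when the Space starts.

```python
# Minimal sketch (not part of this commit): call the dubbing pipeline directly.
# Importing app also loads the models, since app.py does that at module level.
from app import process_video_dubbing

output_path = process_video_dubbing(
    video_path="examples/clip.mp4",        # placeholder path to a movie clip
    subtitle_text="The line the character should speak.",
    audio_path="examples/reference.wav",   # optional timbre reference, or None
)
print("Dubbed video:", output_path)
```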
datasets/CoTMovieDubbing/README.md
ADDED
@@ -0,0 +1,20 @@
## MovieData Preprocessing

TODO

## Data Tree

```
moviecopy
|-- movie1_name
    |-- xxxx.mp4
    |-- xxxx.mp3
    |-- xxxx.txt
    |-- ...
|-- movie2_name
    |-- xxxx.mp4
    |-- xxxx.mp3
    |-- xxxx.txt
    |-- ...
|-- ...
```
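A small helper like the following (an illustrative sketch, not part of the commit) can enumerate the clips implied by this layout, assuming each clip keeps the same `xxxx` stem across its `.mp4`, `.mp3`, and `.txt` files:

```python
# Sketch: walk moviecopy/ and yield (video, audio, transcript) triples.
# Assumes one shared stem per clip, as in the tree above.
from pathlib import Path

def iter_clips(root: str = "moviecopy"):
    for mp4 in sorted(Path(root).glob("*/*.mp4")):
        mp3, txt = mp4.with_suffix(".mp3"), mp4.with_suffix(".txt")
        if mp3.exists() and txt.exists():
            yield mp4, mp3, txt

for video, audio, transcript in iter_clips():
    print(video.name, transcript.read_text().strip()[:60])
```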
datasets/CoTMovieDubbing/filelist/cot_spk_for_speech_gen.lst
ADDED
The diff for this file is too large to render.
See raw diff
datasets/CoTMovieDubbing/filelist/mmlm_test.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
datasets/CoTMovieDubbing/filelist/mmlm_train.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
datasets/Grid/README.md
ADDED
@@ -0,0 +1 @@
Refer to: [Grid](https://paperswithcode.com/dataset/grid)
datasets/V2C/README.md
ADDED
@@ -0,0 +1 @@
Refer to: [V2C](https://github.com/chenqi008/V2C)
datasets/V2C/V2C_Setting2.txt
ADDED
The diff for this file is too large to render.
See raw diff
datasets/V2C/V2C_Setting3.txt
ADDED
The diff for this file is too large to render.
See raw diff
requirements.txt
ADDED
@@ -0,0 +1,237 @@
absl-py==2.1.0
accelerate==0.34.2
addict==2.4.0
aiofiles==24.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.8.0
async-timeout==5.0.1
attrs==25.1.0
audioread==3.0.1
beautifulsoup4==4.13.3
bitsandbytes==0.42.0
blinker==1.9.0
boto3==1.37.12
botocore==1.37.12
cached_path==1.7.1
cachetools==5.5.2
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.3.0
cycler==0.12.1
datasets==3.3.2
decorator==5.2.1
decord==0.6.0
deepspeed==0.15.4
dill==0.3.8
docstring_parser==0.16
einops==0.8.1
einops-exts==0.0.4
einx==0.3.0
eval_type_backport==0.2.2
exceptiongroup==1.2.2
fastapi==0.115.8
ffmpy==0.5.0
filelock==3.13.1
flash-attn==2.6.3
flatbuffers==25.2.10
fonttools==4.56.0
frozendict==2.4.6
frozenlist==1.5.0
fsspec==2024.6.1
future==1.0.0
gdown==5.2.0
gitdb==4.0.12
GitPython==3.1.44
google-api-core==2.24.2
google-auth==2.38.0
google-cloud-core==2.4.3
google-cloud-storage==2.19.0
google-crc32c==1.6.0
google-resumable-media==2.7.2
googleapis-common-protos==1.69.1
gradio==3.35.2
gradio_client==0.2.9
grpcio==1.70.0
h11==0.14.0
hjson==3.1.0
httpcore==0.17.3
httpx==0.24.0
huggingface-hub==0.27.1
humanfriendly==10.0
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.6.1
importlib_resources==6.5.2
jieba==0.42.1
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.7
latex2mathml==3.77.0
lazy_loader==0.4
librosa==0.11.0
liger_kernel==0.4.2
linkify-it-py==2.0.3
llvmlite==0.43.0
loguru==0.7.3
Markdown==3.7
markdown-it-py==2.2.0
markdown2==2.5.3
MarkupSafe==2.1.5
matplotlib==3.9.4
mdit-py-plugins==0.3.3
mdurl==0.1.2
mmcls==0.25.0
mmcv-full==1.6.2
mmsegmentation==0.30.0
model-index==0.1.11
moviepy==2.1.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
narwhals==1.28.0
networkx==3.2.1
ninja==1.11.1.3
numba==0.60.0
numpy==1.26.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==9.1.0.70
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu11==10.3.0.86
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu11==11.7.5.86
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.570.86
nvidia-nccl-cu11==2.20.5
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.8.61
nvidia-nvtx-cu11==11.8.86
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
onnxruntime==1.18.0
opencv-python==4.11.0.86
opendatalab==0.0.10
openmim==0.3.9
openxlab==0.0.11
ordered-set==4.1.0
orjson==3.10.15
packaging==24.2
pandas==2.2.3
peft==0.10.0
pillow==10.4.0
platformdirs==4.3.6
pooch==1.8.2
prettytable==3.14.0
proglog==0.1.10
propcache==0.3.0
proto-plus==1.26.1
protobuf==5.29.3
psutil==7.0.0
py-cpuinfo==9.0.0
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocoevalcap==1.2
pycocotools==2.0.8
pycparser==2.22
pycryptodome==3.21.0
pydantic==2.10.6
pydantic_core==2.27.2
pydeck==0.9.1
pydub==0.25.1
Pygments==2.19.1
pyparsing==3.2.1
pypinyin==0.53.0
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
pytz==2025.1
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rpds-py==0.23.1
rsa==4.9
s3transfer==0.11.4
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.13.1
semantic-version==2.10.0
sentencepiece==0.1.99
shortuuid==1.0.13
shtab==1.7.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soundfile==0.13.1
soupsieve==2.6
soxr==0.5.0.post1
starlette==0.45.3
streamlit==1.42.2
streamlit-image-select==0.6.0
svgwrite==1.4.3
sympy==1.13.1
tabulate==0.9.0
tenacity==9.0.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
termcolor==2.5.0
threadpoolctl==3.5.0
timm==0.9.12
tokenizers==0.19.1
toml==0.10.2
tomli==2.2.1
torch==2.4.1
torchaudio==2.4.1
torchdiffeq==0.2.5
torchvision==0.19.1
tornado==6.4.2
tqdm==4.67.1
transformers==4.42.1
triton==3.0.0
trl==0.10.1
typeguard==4.4.2
typing_extensions==4.12.2
tyro==0.9.16
tzdata==2025.1
uc-micro-py==1.0.3
urllib3==1.26.20
uvicorn==0.34.0
watchdog==6.0.0
wavedrom==2.0.3.post3
wcwidth==0.2.13
websockets==15.0
Werkzeug==3.1.3
x-transformers==2.1.37
xxhash==3.5.0
yacs==0.1.8
yapf==0.40.1
yarl==1.18.3
zipp==3.21.0
ruff.toml
ADDED
@@ -0,0 +1,11 @@
line-length = 120
target-version = "py310"

[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"
ignore = ["E402"]

[lint.isort]
force-single-line = false
lines-after-imports = 2
src/internvl/eval.py
ADDED
@@ -0,0 +1,337 @@
import json
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoTokenizer


sys.path.insert(0, os.path.join(str(Path(__file__).resolve().parents[2]), "src/third_party/InternVL/internvl_chat"))
from internvl.model.internvl_chat.modeling_internvl_chat import InternVLChatModel  # type: ignore


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert("RGB")
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array(
        [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)]
    )
    return frame_indices


def load_video(
    video_path,
    bound=None,
    input_size=448,
    max_num=1,
    num_segments=32,
    cache_dir=".cache/expcache",
):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())

    video_cache_dir = video_path.split("/")[-2] + "_" + os.path.basename(video_path).split(".")[0]
    video_cache_dir = os.path.join(cache_dir, video_cache_dir)
    cache_filename = os.path.join(
        video_cache_dir,
        f"_bound-{bound}_input_size-{input_size}_max_num-{max_num}_num_segments-{num_segments}.pt",
    )
    if os.path.exists(cache_filename) and os.path.isfile(cache_filename):
        cache = torch.load(cache_filename, weights_only=True)
        pixel_values = cache["pixel_values"]
        num_patches_list = cache["num_patches_list"]

    else:
        pixel_values_list, num_patches_list = [], []
        transform = build_transform(input_size=input_size)
        frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)

        frame_indices = np.append(0, frame_indices)  # Add 0 at the beginning of the list
        frame_indices = np.append(frame_indices, max_frame)  # Add max_frame at the end of the list

        os.makedirs(video_cache_dir, exist_ok=True)

        idx = 0
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")

            img.save(os.path.join(video_cache_dir, f"frame_{frame_index}_tile_{idx}.png"))

            img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)

            pixel_values = [transform(tile) for tile in img]
            pixel_values = torch.stack(pixel_values)
            num_patches_list.append(pixel_values.shape[0])
            pixel_values_list.append(pixel_values)

            idx += 1
        pixel_values = torch.cat(pixel_values_list)

        os.makedirs(cache_dir, exist_ok=True)
        torch.save({"pixel_values": pixel_values, "num_patches_list": num_patches_list}, cache_filename)

    return pixel_values, num_patches_list


def analyze_predictions(file_path):
    # Read the CSV file
    df = pd.read_csv(file_path)

    # Calculate overall accuracy
    total_samples = len(df)
    correct_predictions = df["is_correct"].value_counts().get(True, 0)
    overall_accuracy = correct_predictions / total_samples

    # Initialize metrics for each class
    classes = ["A", "B", "C"]
    class_metrics = {}

    for cls in classes:
        # Filter for samples where target is this class
        true_class = df[df["target"] == cls]
        # Filter for samples where prediction is this class
        # pred_class = df[df["predict"] == cls]

        # Calculate TP, FP, FN
        TP = len(df[(df["target"] == cls) & (df["predict"] == cls)])
        FP = len(df[(df["target"] != cls) & (df["predict"] == cls)])
        FN = len(df[(df["target"] == cls) & (df["predict"] != cls)])

        # Calculate precision, recall, F1
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # Store metrics
        class_metrics[cls] = {
            "total_samples": len(true_class),
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": TP,
            "false_positives": FP,
            "false_negatives": FN,
        }

    print(f"Overall Accuracy: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})")
    print()
    print("Indicators for each category:")

    for cls in classes:
        metrics = class_metrics[cls]
        print(f"  Class {cls}:")
        print(f"    Total Samples: {metrics['total_samples']}")
        print(f"    Precision: {metrics['precision']:.4f}")
        print(f"    Recall: {metrics['recall']:.4f}")
        print(f"    F1 Score: {metrics['f1']:.4f}")
        print(f"    True Positives: {metrics['true_positives']}")
        print(f"    False Positives: {metrics['false_positives']}")
        print(f"    False Negatives: {metrics['false_negatives']}")

    return overall_accuracy, class_metrics


def s_thread(video_dir, model_path, device, chunk, idx, queue):
    model = InternVLChatModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
    )
    model = model.eval().to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

    generation_config = dict(max_new_tokens=1024, do_sample=False)

    res = []
    for line in tqdm(chunk, position=idx, desc=f"Device {device}"):
        data = json.loads(line)

        video_path = os.path.join(video_dir, data["video"])
        ques = data["conversations"][0]["value"]

        target_ans = data["conversations"][1]["value"].split("<CONCLUSION>")[1].split("</CONCLUSION>")[0].strip()

        pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
        pixel_values = pixel_values.to(torch.bfloat16).to(device)
        video_prefix = "".join([f"Frame{i + 1}: <image>\n" for i in range(len(num_patches_list))])
        question = video_prefix + f"{ques}"
        response = model.chat(
            tokenizer,
            pixel_values,
            question,
            generation_config,
            num_patches_list=num_patches_list,
            history=None,
            return_history=False,
        )

        try:
            ans = response.split("<CONCLUSION>")[1].split("</CONCLUSION>")[0].strip()
        except Exception as e:
            print(f"Error: {e}, response: {response}")
            ans = response.strip()[0]

        is_correct = False
        if ans == target_ans:
            is_correct = True

        res.append(f"{video_path},{is_correct},{target_ans},{ans}")

    queue.put(res)


if __name__ == "__main__":
    import argparse

    import torch.multiprocessing as mp

    parser = argparse.ArgumentParser(description="eval script for mmlm")
    parser.add_argument("--model_path", type=str, help="Path to the model checkpoint.")
    parser.add_argument("--test_file", type=str, help="Path to the test file.")
    parser.add_argument("--video_dir", type=str, help="Path to the test video directory.")
    parser.add_argument("--gpuids", type=str, help="GPU ids to use.")

    # python eval.py --model_path /path/to/model --test_file /path/to/test_file --video_dir /path/to/video_dir --gpuids 0,1,2,3

    args = parser.parse_args()

    model_path = args.model_path
    test_file = args.test_file
    video_dir = args.video_dir

    gpu_ids = args.gpuids.split(",") if args.gpuids else ["0"]

    cot_test = Path(test_file).read_text().splitlines()

    chunks = np.array_split(cot_test, len(gpu_ids))

    mp.set_start_method("spawn", force=True)

    queue = mp.Queue()

    processes = []
    for idx, chunk in enumerate(chunks):
        device = gpu_ids[idx % len(gpu_ids)]
        device = f"cuda:{device}"

        p = mp.Process(target=s_thread, args=(video_dir, model_path, device, chunk, idx, queue))
        processes.append(p)
        p.start()

    for process in processes:
        process.join()

    result = []
    for _ in range(len(chunks)):
        res = queue.get()
        result.extend(res)

    res_saved = f"{'__'.join(model_path.split('/'))}_res.csv"
    with open(res_saved, "w") as f:
        f.write("video_id,is_correct,target,predict\n")
        for res in result:
            f.write(f"{res}\n")

    accuracy, metrics = analyze_predictions(res_saved)

    print("All processes finished.\n\n")
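For reference, `analyze_predictions` consumes the CSV that the `__main__` block above writes: a `video_id,is_correct,target,predict` header followed by one row per clip, with answers restricted to A/B/C. The snippet below is an illustrative sketch with made-up rows, assuming the function is importable in the current session.

```python
# Sketch: build a tiny predictions CSV in the expected format and score it.
# The rows are fabricated for illustration only.
import pandas as pd

rows = [
    ("movie1/clip_001.mp4", True, "A", "A"),
    ("movie1/clip_002.mp4", False, "B", "C"),
    ("movie2/clip_003.mp4", True, "C", "C"),
]
pd.DataFrame(rows, columns=["video_id", "is_correct", "target", "predict"]).to_csv("demo_res.csv", index=False)

accuracy, class_metrics = analyze_predictions("demo_res.csv")  # prints per-class precision/recall/F1
print(accuracy, class_metrics["B"]["recall"])
```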
src/moviedubber/configs/basemodel.yaml
ADDED
@@ -0,0 +1,9 @@
model:
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    conv_layers: 4
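These keys line up with the keyword arguments that `load_model` in `src/moviedubber/infer/utils_infer.py` forwards as `model_cfg`. The sketch below shows one plausible way to read them; the commented-out wiring and the checkpoint/vocab paths are assumptions, not code from this commit.

```python
# Sketch: read the arch block and pass it through as model_cfg (assumed wiring).
import yaml

with open("src/moviedubber/configs/basemodel.yaml") as f:
    model_cfg = yaml.safe_load(f)["model"]["arch"]

print(model_cfg)  # {'dim': 1024, 'depth': 22, 'heads': 16, 'ff_mult': 2, 'text_dim': 512, 'conv_layers': 4}

# from src.moviedubber.model import DiT                      # hypothetical import path
# from src.moviedubber.infer.utils_infer import load_model
# ema_model = load_model(DiT, model_cfg, ckpt_path="/path/to/ckpt_file.pth",
#                        vocab_file="/path/to/vocab_file.txt")
```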
src/moviedubber/eval.py
ADDED
@@ -0,0 +1,245 @@
import argparse
import os
import string
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import librosa
import numpy as np
import torch
from evaluate import load
from pymcd.mcd import Calculate_MCD
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Wav2Vec2FeatureExtractor, WavLMForXVector, pipeline


def convert_numbers_to_words(text):
    """Convert single digits in text to words with spaces"""
    number_word_map = {
        "0": "zero",
        "1": "one",
        "2": "two",
        "3": "three",
        "4": "four",
        "5": "five",
        "6": "six",
        "7": "seven",
        "8": "eight",
        "9": "nine",
    }

    words = text.split()
    converted_words = []

    for word in words:
        # Check if the word contains both letters and numbers (like 'j4')
        if any(c.isdigit() for c in word) and any(c.isalpha() for c in word):
            # Split the word into parts and convert digits
            new_word = ""
            for c in word:
                if c.isdigit():
                    new_word += " " + number_word_map[c]
                else:
                    new_word += c
            converted_words.append(new_word)
        # Check if the word is a single digit
        elif word.isdigit() and len(word) == 1:
            converted_words.append(number_word_map[word])
        else:
            converted_words.append(word)

    return " ".join(converted_words)


def clean_text(text):
    text = convert_numbers_to_words(text)
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    return text


def wer_pipe(gen_dir: str, target_dir: str, model_id="openai/whisper-large-v3-turbo"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Using Model: {model_id} for WER Evaluation")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    gen_list = list(Path(gen_dir).glob("*.wav"))
    for line in tqdm(gen_list, desc="Processing audio files"):
        wav = line
        if not wav.exists():
            continue

        text = pipe(librosa.load(wav, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
        with open(wav.with_suffix(".asrtxt"), "w") as fw:
            fw.write(text)

    wer_metric = load("wer")

    val_list = list(Path(target_dir).glob("*.txt"))

    wer = []
    for txt in tqdm(val_list, desc="Calculating WER"):
        try:
            # Since the original text is automatically transcribed and has not been manually verified, all texts will be cleaned here.

            target_text = " ".join(set(txt.read_text().splitlines()))
            target_text = clean_text(target_text)

            gen_text = " ".join(Path(os.path.join(gen_dir, txt.with_suffix(".asrtxt").name)).read_text().splitlines())
            gen_text = clean_text(gen_text)

            if target_text == "" or gen_text == "":
                continue

            wer_ = wer_metric.compute(references=[target_text], predictions=[gen_text])

        except Exception as e:
            print("Error in wer calculation: ", e)
            continue

        wer.append(wer_)

    return np.mean(wer)


def spk_sim_pipe(gen_dir, target_dir):
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
    model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv").cuda()

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    val_list = list(Path(target_dir).glob("*.wav"))

    scos = []

    for target_wav in tqdm(val_list, desc="Calculating speaker similarity"):
        target = librosa.load(target_wav, sr=16000)[0]
        gen = librosa.load(os.path.join(gen_dir, target_wav.name), sr=16000)[0]

        try:
            input1 = feature_extractor(gen, return_tensors="pt", sampling_rate=16000).to("cuda")
            embeddings1 = model(**input1).embeddings

            input2 = feature_extractor(target, return_tensors="pt", sampling_rate=16000).to("cuda")
            embeddings2 = model(**input2).embeddings

            similarity = cosine_sim(embeddings1[0], embeddings2[0])

        except Exception as e:
            print(f"Error in {target_wav}, {e}")
            continue

        scos.append(similarity.detach().cpu().numpy())

    return np.mean(scos)


def calculate_mcd_for_wav(target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl):
    _mcd_dtw = mcd_toolbox_dtw.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
    _mcd_dtw_sl = mcd_toolbox_dtw_sl.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
    return _mcd_dtw, _mcd_dtw_sl


def mcd_pipe(gen_dir, target_dir, num_processes=16):
    mcd_toolbox_dtw = Calculate_MCD(MCD_mode="dtw")
    mcd_toolbox_dtw_sl = Calculate_MCD(MCD_mode="dtw_sl")

    val_list = list(Path(target_dir).glob("*.wav"))

    mcd_dtw = []
    mcd_dtw_sl = []

    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = [
            executor.submit(calculate_mcd_for_wav, target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl)
            for target_wav in val_list
        ]
        for future in tqdm(futures, desc="Calculating MCD"):
            _mcd_dtw, _mcd_dtw_sl = future.result()
            mcd_dtw.append(_mcd_dtw)
            mcd_dtw_sl.append(_mcd_dtw_sl)

    return np.mean(mcd_dtw), np.mean(mcd_dtw_sl)


def run_all_metrics(gen_dir, target_dir, whisper_model="openai/whisper-large-v3-turbo"):
    """Run all evaluation metrics and return results"""
    results = {}

    print("Running WER evaluation...")
    results["wer"] = wer_pipe(gen_dir, target_dir, model_id=whisper_model)

    print("Running speaker similarity evaluation...")
    results["speaker_similarity"] = spk_sim_pipe(gen_dir, target_dir)

    print("Running MCD evaluation...")
    mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
    results["mcd_dtw"] = mcd_dtw
    results["mcd_dtw_sl"] = mcd_dtw_sl

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio evaluation metrics")
    parser.add_argument("--gen_dir", type=str, required=True, help="Directory containing generated audio files")
    parser.add_argument("--target_dir", type=str, required=True, help="Directory containing target audio files")
    parser.add_argument(
        "--metric",
        type=str,
        default="all",
        choices=["wer", "spk_sim", "mcd", "all"],
        help="Evaluation metric to use",
    )
    parser.add_argument(
        "--whisper_model",
        type=str,
        default="openai/whisper-large-v3-turbo",
        help="Whisper model to use for WER evaluation",
    )
    # python eval.py --gen_dir path/to/generated --target_dir path/to/target
    # keep the name of gen_wav and target_wav the same
    args = parser.parse_args()

    gen_dir = args.gen_dir
    target_dir = args.target_dir

    if not os.path.exists(gen_dir):
        raise ValueError(f"Generated audio directory does not exist: {gen_dir}")
    if not os.path.exists(target_dir):
        raise ValueError(f"Target audio directory does not exist: {target_dir}")

    if args.metric == "all":
        results = run_all_metrics(gen_dir, target_dir, args.whisper_model)
        print("\nEvaluation Results:")
        print(f"WER: {results['wer']:.4f}")
        print(f"Speaker Similarity: {results['speaker_similarity']:.4f}")
        print(f"MCD (DTW): {results['mcd_dtw']:.4f}")
        print(f"MCD (DTW-SL): {results['mcd_dtw_sl']:.4f}")

    elif args.metric == "wer":
        wer = wer_pipe(gen_dir, target_dir, model_id=args.whisper_model)
        print(f"WER: {wer:.4f}")

    elif args.metric == "spk_sim":
        spk_sim = spk_sim_pipe(gen_dir, target_dir)
        print(f"Speaker Similarity: {spk_sim:.4f}")

    elif args.metric == "mcd":
        mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
        print(f"MCD (DTW): {mcd_dtw:.4f}")
        print(f"MCD (DTW-SL): {mcd_dtw_sl:.4f}")
src/moviedubber/infer/basic.toml
ADDED
@@ -0,0 +1,4 @@
ckpt_file = "/path/to/ckpt_file.pth"
vocab_file = "/path/to/vocab_file.txt"
vocoder_local_path = "/path/to/bigvgan"
campplus_path = "/path/to/campplus.onnx"
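app.py consumes this file by parsing it with `tomli` and handing the resulting dict to `load_models`. The sketch below repeats that pattern outside the app; the four values are the placeholders above and must point at real local files before anything will load.

```python
# Sketch: load the inference config the same way app.py does.
import tomli

with open("src/moviedubber/infer/basic.toml", "rb") as f:
    config = tomli.load(f)

print(config["ckpt_file"], config["vocoder_local_path"])

# from src.moviedubber.infer_with_mmlm_result import load_models
# ema_model, vocoder, ort_session = load_models(config, device="cpu")
```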
src/moviedubber/infer/utils_infer.py
ADDED
@@ -0,0 +1,399 @@
1 |
+
# A unified script for inference process
|
2 |
+
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
3 |
+
|
4 |
+
import re
|
5 |
+
from importlib.resources import files
|
6 |
+
|
7 |
+
import matplotlib
|
8 |
+
|
9 |
+
|
10 |
+
matplotlib.use("Agg")
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torchaudio
|
16 |
+
import tqdm
|
17 |
+
|
18 |
+
from src.moviedubber.model import CFM
|
19 |
+
from src.moviedubber.model.utils import convert_char_to_pinyin, get_tokenizer
|
20 |
+
|
21 |
+
|
22 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
23 |
+
|
24 |
+
# -----------------------------------------
|
25 |
+
|
26 |
+
target_sample_rate = 24000
|
27 |
+
n_mel_channels = 100
|
28 |
+
hop_length = 256
|
29 |
+
win_length = 1024
|
30 |
+
n_fft = 1024
|
31 |
+
mel_spec_type = "bigvgan"
|
32 |
+
target_rms = 0.1
|
33 |
+
cross_fade_duration = 0.15
|
34 |
+
ode_method = "euler"
|
35 |
+
nfe_step = 32 # 16, 32
|
36 |
+
# cfg_strength = 2.0
|
37 |
+
cfg_strength = 1
|
38 |
+
sway_sampling_coef = -1.0
|
39 |
+
speed = 1.0
|
40 |
+
fix_duration = None
|
41 |
+
|
42 |
+
# -----------------------------------------
|
43 |
+
|
44 |
+
|
45 |
+
# chunk text into smaller pieces
|
46 |
+
|
47 |
+
|
48 |
+
def chunk_text(text, max_chars=135):
|
49 |
+
"""
|
50 |
+
Splits the input text into chunks, each with a maximum number of characters.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
text (str): The text to be split.
|
54 |
+
max_chars (int): The maximum number of characters per chunk.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
List[str]: A list of text chunks.
|
58 |
+
"""
|
59 |
+
chunks = []
|
60 |
+
current_chunk = ""
|
61 |
+
# Split the text into sentences based on punctuation followed by whitespace
|
62 |
+
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
63 |
+
|
64 |
+
for sentence in sentences:
|
65 |
+
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
66 |
+
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
67 |
+
else:
|
68 |
+
if current_chunk:
|
69 |
+
chunks.append(current_chunk.strip())
|
70 |
+
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
71 |
+
|
72 |
+
if current_chunk:
|
73 |
+
chunks.append(current_chunk.strip())
|
74 |
+
|
75 |
+
return chunks
|
76 |
+
|
77 |
+
|
78 |
+
# load vocoder
|
79 |
+
def load_vocoder(local_path, device=device):
|
80 |
+
from src.third_party.BigVGAN import bigvgan
|
81 |
+
|
82 |
+
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
83 |
+
|
84 |
+
vocoder.remove_weight_norm()
|
85 |
+
vocoder = vocoder.eval().to(device)
|
86 |
+
return vocoder
|
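A short usage sketch for load_vocoder; the checkpoint directory is a placeholder for a locally downloaded BigVGAN model, not a path shipped with this repo:

    vocoder = load_vocoder("ckpts/bigvgan_v2_24khz_100band_256x", device=device)  # hypothetical local dir
    mel = torch.randn(1, n_mel_channels, 200, device=device)  # fake mel: (batch, n_mels, frames)
    with torch.inference_mode():
        wave = vocoder(mel)  # BigVGAN maps mel frames back to a waveform tensor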
87 |
+
|
88 |
+
|
89 |
+
# load model checkpoint for inference
|
90 |
+
|
91 |
+
|
92 |
+
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
93 |
+
if dtype is None:
|
94 |
+
dtype = (
|
95 |
+
torch.float16
|
96 |
+
if "cuda" in device
|
97 |
+
and torch.cuda.get_device_properties(device).major >= 6
|
98 |
+
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
99 |
+
else torch.float32
|
100 |
+
)
|
101 |
+
model = model.to(dtype)
|
102 |
+
|
103 |
+
ckpt_type = ckpt_path.split(".")[-1]
|
104 |
+
if ckpt_type == "safetensors":
|
105 |
+
from safetensors.torch import load_file
|
106 |
+
|
107 |
+
checkpoint = load_file(ckpt_path, device=device)
|
108 |
+
else:
|
109 |
+
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
110 |
+
|
111 |
+
if use_ema:
|
112 |
+
if ckpt_type == "safetensors":
|
113 |
+
checkpoint = {"ema_model_state_dict": checkpoint}
|
114 |
+
checkpoint["model_state_dict"] = {
|
115 |
+
k.replace("ema_model.", ""): v
|
116 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
117 |
+
if k not in ["initted", "step"]
|
118 |
+
}
|
119 |
+
|
120 |
+
# patch for backward compatibility, 305e3ea
|
121 |
+
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
122 |
+
if key in checkpoint["model_state_dict"]:
|
123 |
+
del checkpoint["model_state_dict"][key]
|
124 |
+
|
125 |
+
state_dict_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
126 |
+
if state_dict_result.unexpected_keys:
|
127 |
+
print("\nUnexpected keys in state_dict:", state_dict_result.unexpected_keys)
|
128 |
+
if state_dict_result.missing_keys:
|
129 |
+
print("\nMissing keys in state_dict:", state_dict_result.missing_keys)
|
130 |
+
else:
|
131 |
+
if ckpt_type == "safetensors":
|
132 |
+
checkpoint = {"model_state_dict": checkpoint}
|
133 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
134 |
+
|
135 |
+
del checkpoint
|
136 |
+
torch.cuda.empty_cache()
|
137 |
+
|
138 |
+
return model.to(device)
|
139 |
+
|
140 |
+
|
141 |
+
# load model for inference
|
142 |
+
|
143 |
+
|
144 |
+
def load_model(
|
145 |
+
model_cls,
|
146 |
+
model_cfg,
|
147 |
+
ckpt_path,
|
148 |
+
controlnet=None,
|
149 |
+
mel_spec_type=mel_spec_type,
|
150 |
+
vocab_file="",
|
151 |
+
ode_method=ode_method,
|
152 |
+
use_ema=True,
|
153 |
+
device=device,
|
154 |
+
):
|
155 |
+
tokenizer = "custom"
|
156 |
+
|
157 |
+
print("\nvocab : ", vocab_file)
|
158 |
+
print("token : ", tokenizer)
|
159 |
+
print("model : ", ckpt_path, "\n")
|
160 |
+
|
161 |
+
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
162 |
+
|
163 |
+
if controlnet is not None:
|
164 |
+
controlnet = controlnet(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
|
165 |
+
|
166 |
+
model = CFM(
|
167 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
168 |
+
mel_spec_kwargs=dict(
|
169 |
+
n_fft=n_fft,
|
170 |
+
hop_length=hop_length,
|
171 |
+
win_length=win_length,
|
172 |
+
n_mel_channels=n_mel_channels,
|
173 |
+
target_sample_rate=target_sample_rate,
|
174 |
+
mel_spec_type=mel_spec_type,
|
175 |
+
),
|
176 |
+
odeint_kwargs=dict(
|
177 |
+
method=ode_method,
|
178 |
+
),
|
179 |
+
vocab_char_map=vocab_char_map,
|
180 |
+
controlnet=controlnet,
|
181 |
+
).to(device)
|
182 |
+
|
183 |
+
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
184 |
+
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
185 |
+
|
186 |
+
return model
|
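The call below mirrors how load_models in infer_with_mmlm_result.py wires this together; the checkpoint and vocab paths are placeholders:

    from omegaconf import OmegaConf

    from src.moviedubber.model import ControlNetDiT, DiT

    model_arch = OmegaConf.load("src/moviedubber/configs/basemodel.yaml").model.arch
    ema_model = load_model(
        DiT,
        model_arch,
        "ckpts/moviedubber.pt",  # hypothetical checkpoint path
        controlnet=ControlNetDiT,
        vocab_file="ckpts/vocab.txt",  # hypothetical vocab file
    )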
187 |
+
|
188 |
+
|
189 |
+
def infer_process(
|
190 |
+
ref_audio,
|
191 |
+
ref_text,
|
192 |
+
ref_clip,
|
193 |
+
ref_lip,
|
194 |
+
gen_text,
|
195 |
+
gen_clip,
|
196 |
+
gen_lip,
|
197 |
+
model_obj,
|
198 |
+
vocoder,
|
199 |
+
gen_caption=None,
|
200 |
+
mel_spec_type=mel_spec_type,
|
201 |
+
progress=tqdm,
|
202 |
+
target_rms=target_rms,
|
203 |
+
cross_fade_duration=cross_fade_duration,
|
204 |
+
nfe_step=nfe_step,
|
205 |
+
cfg_strength=cfg_strength,
|
206 |
+
sway_sampling_coef=sway_sampling_coef,
|
207 |
+
speed=speed,
|
208 |
+
fix_duration=fix_duration,
|
209 |
+
device=device,
|
210 |
+
):
|
211 |
+
# Split the input text into batches
|
212 |
+
audio, sr = torchaudio.load(ref_audio)
|
213 |
+
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
214 |
+
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
215 |
+
|
216 |
+
return infer_batch_process(
|
217 |
+
(audio, sr),
|
218 |
+
ref_text,
|
219 |
+
ref_clip,
|
220 |
+
ref_lip,
|
221 |
+
gen_text_batches,
|
222 |
+
gen_clip,
|
223 |
+
gen_lip,
|
224 |
+
model_obj,
|
225 |
+
vocoder,
|
226 |
+
gen_caption=gen_caption,
|
227 |
+
mel_spec_type=mel_spec_type,
|
228 |
+
progress=progress,
|
229 |
+
target_rms=target_rms,
|
230 |
+
cross_fade_duration=cross_fade_duration,
|
231 |
+
nfe_step=nfe_step,
|
232 |
+
cfg_strength=cfg_strength,
|
233 |
+
sway_sampling_coef=sway_sampling_coef,
|
234 |
+
speed=speed,
|
235 |
+
fix_duration=fix_duration,
|
236 |
+
device=device,
|
237 |
+
)
|
238 |
+
|
239 |
+
|
240 |
+
# infer batches
|
241 |
+
|
242 |
+
|
243 |
+
def infer_batch_process(
|
244 |
+
ref_audio,
|
245 |
+
ref_text,
|
246 |
+
ref_clip,
|
247 |
+
ref_lip,
|
248 |
+
gen_text_batches,
|
249 |
+
gen_clip,
|
250 |
+
gen_lip,
|
251 |
+
model_obj,
|
252 |
+
vocoder,
|
253 |
+
gen_caption=None,
|
254 |
+
mel_spec_type="vocos",
|
255 |
+
target_rms=0.1,
|
256 |
+
cross_fade_duration=0.15,
|
257 |
+
nfe_step=32,
|
258 |
+
cfg_strength=2.0,
|
259 |
+
sway_sampling_coef=-1,
|
260 |
+
speed=1,
|
261 |
+
fix_duration=None,
|
262 |
+
device=None,
|
263 |
+
):
|
264 |
+
audio, sr = ref_audio
|
265 |
+
if audio.shape[0] > 1:
|
266 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
267 |
+
|
268 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
269 |
+
if rms < target_rms:
|
270 |
+
audio = audio * target_rms / rms
|
271 |
+
if sr != target_sample_rate:
|
272 |
+
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
273 |
+
audio = resampler(audio)
|
274 |
+
audio = audio.to(device)
|
275 |
+
|
276 |
+
generated_waves = []
|
277 |
+
spectrograms = []
|
278 |
+
|
279 |
+
if len(ref_text[-1].encode("utf-8")) == 1:
|
280 |
+
ref_text = ref_text + " "
|
281 |
+
|
282 |
+
for i, gen_text in enumerate(gen_text_batches):
|
283 |
+
# Prepare the text
|
284 |
+
text_list = [ref_text + gen_text]
|
285 |
+
final_text_list = convert_char_to_pinyin(text_list)
|
286 |
+
|
287 |
+
ref_audio_len = audio.shape[-1] // hop_length
|
288 |
+
if ref_clip is not None:
|
289 |
+
ref_clip = F.interpolate(
|
290 |
+
ref_clip.unsqueeze(0).transpose(1, 2), size=ref_audio_len, mode="linear", align_corners=False
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
ref_clip = torch.zeros(1, 768, ref_audio_len).to(device)
|
294 |
+
|
295 |
+
if fix_duration is not None:
|
296 |
+
duration = int(fix_duration * target_sample_rate / hop_length)
|
297 |
+
gen_audio_len = duration - ref_audio_len
|
298 |
+
|
299 |
+
gen_clip = F.interpolate(
|
300 |
+
gen_clip.unsqueeze(0).transpose(1, 2), size=gen_audio_len, mode="linear", align_corners=False
|
301 |
+
)
|
302 |
+
|
303 |
+
else:
|
304 |
+
# Calculate duration
|
305 |
+
ref_text_len = len(ref_text.encode("utf-8"))
|
306 |
+
gen_text_len = len(gen_text.encode("utf-8"))
|
307 |
+
|
308 |
+
gen_audio_len = int(ref_audio_len / ref_text_len * gen_text_len)
|
309 |
+
gen_clip = F.interpolate(
|
310 |
+
gen_clip.unsqueeze(0).transpose(1, 2), size=gen_audio_len, mode="linear", align_corners=False
|
311 |
+
)
|
312 |
+
|
313 |
+
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
314 |
+
|
315 |
+
if ref_lip is None:
|
316 |
+
ref_lip = torch.zeros(ref_audio_len // 4, 512)
|
317 |
+
|
318 |
+
clip = torch.cat([ref_clip, gen_clip], dim=-1).permute(0, 2, 1).to(device)
|
319 |
+
|
320 |
+
if gen_lip is not None:
|
321 |
+
lip = torch.cat([ref_lip.unsqueeze(0).transpose(1, 2), gen_lip.unsqueeze(0).transpose(1, 2)], dim=-1).to(
|
322 |
+
device
|
323 |
+
)
|
324 |
+
lip = F.pad(lip, (0, duration - lip.size(-1)), value=0).permute(0, 2, 1)
|
325 |
+
else:
|
326 |
+
lip = None
|
327 |
+
|
328 |
+
# inference
|
329 |
+
with torch.inference_mode():
|
330 |
+
generated, _ = model_obj.sample(
|
331 |
+
cond=audio,
|
332 |
+
text=final_text_list,
|
333 |
+
clip=clip,
|
334 |
+
lip=lip,
|
335 |
+
caption_emb=gen_caption,
|
336 |
+
duration=duration,
|
337 |
+
steps=nfe_step,
|
338 |
+
cfg_strength=cfg_strength,
|
339 |
+
sway_sampling_coef=sway_sampling_coef,
|
340 |
+
no_ref_audio=False,
|
341 |
+
)
|
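# note: CFM.sample as defined in model/cfm.py does not accept a `lip` keyword; the MMLM
# inference path (infer_with_mmlm_result.py) conditions on spk_emb instead, so this helper
# appears to target an earlier sampler signature.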
342 |
+
|
343 |
+
generated = generated.to(torch.float32)
|
344 |
+
generated = generated[:, ref_audio_len:, :]
|
345 |
+
generated_mel_spec = generated.permute(0, 2, 1)
|
346 |
+
if mel_spec_type == "vocos":
|
347 |
+
generated_wave = vocoder.decode(generated_mel_spec)
|
348 |
+
elif mel_spec_type == "bigvgan":
|
349 |
+
generated_wave = vocoder(generated_mel_spec)
|
350 |
+
if rms < target_rms:
|
351 |
+
generated_wave = generated_wave * rms / target_rms
|
352 |
+
|
353 |
+
# wav -> numpy
|
354 |
+
generated_wave = generated_wave.squeeze().cpu().numpy()
|
355 |
+
|
356 |
+
generated_waves.append(generated_wave)
|
357 |
+
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
358 |
+
|
359 |
+
# Combine all generated waves with cross-fading
|
360 |
+
if cross_fade_duration <= 0:
|
361 |
+
# Simply concatenate
|
362 |
+
final_wave = np.concatenate(generated_waves)
|
363 |
+
else:
|
364 |
+
final_wave = generated_waves[0]
|
365 |
+
for i in range(1, len(generated_waves)):
|
366 |
+
prev_wave = final_wave
|
367 |
+
next_wave = generated_waves[i]
|
368 |
+
|
369 |
+
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
370 |
+
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
371 |
+
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
372 |
+
|
373 |
+
if cross_fade_samples <= 0:
|
374 |
+
# No overlap possible, concatenate
|
375 |
+
final_wave = np.concatenate([prev_wave, next_wave])
|
376 |
+
continue
|
377 |
+
|
378 |
+
# Overlapping parts
|
379 |
+
prev_overlap = prev_wave[-cross_fade_samples:]
|
380 |
+
next_overlap = next_wave[:cross_fade_samples]
|
381 |
+
|
382 |
+
# Fade out and fade in
|
383 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
384 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
385 |
+
|
386 |
+
# Cross-faded overlap
|
387 |
+
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
388 |
+
|
389 |
+
# Combine
|
390 |
+
new_wave = np.concatenate(
|
391 |
+
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
392 |
+
)
|
393 |
+
|
394 |
+
final_wave = new_wave
|
395 |
+
|
396 |
+
# Create a combined spectrogram
|
397 |
+
combined_spectrogram = np.concatenate(spectrograms, axis=1)
|
398 |
+
|
399 |
+
return final_wave, target_sample_rate, combined_spectrogram
|
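The cross-fade stitching above can be sanity-checked in isolation; the two waveforms and the overlap length below are made up for illustration:

    import numpy as np

    sr = 24000
    a, b = np.ones(sr), np.zeros(sr)  # two fake 1-second segments
    n = int(0.15 * sr)  # overlap, as with cross_fade_duration = 0.15
    fade_out, fade_in = np.linspace(1, 0, n), np.linspace(0, 1, n)
    joined = np.concatenate([a[:-n], a[-n:] * fade_out + b[:n] * fade_in, b[n:]])
    assert len(joined) == len(a) + len(b) - n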
src/moviedubber/infer/video_preprocess.py
ADDED
@@ -0,0 +1,315 @@
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Optional, Union
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import imageio
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.multiprocessing as mp
|
14 |
+
from decord import AudioReader, VideoReader, cpu
|
15 |
+
from PIL import Image
|
16 |
+
from tqdm import tqdm
|
17 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
18 |
+
|
19 |
+
|
20 |
+
logging.basicConfig(level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s")
|
21 |
+
|
22 |
+
|
23 |
+
NUM_FRAMES = None # NUM_FRAMES = 160
|
24 |
+
MAX_FRAMES = None # MAX_FRAMES = 256
|
25 |
+
NUM_FRAMES_PER_SECOND = 10
|
26 |
+
|
27 |
+
|
28 |
+
def get_full_indices(reader: Union[VideoReader, AudioReader]) -> np.ndarray:
|
29 |
+
if isinstance(reader, VideoReader):
|
30 |
+
return np.linspace(0, len(reader) - 1, len(reader), dtype=int)
|
31 |
+
elif isinstance(reader, AudioReader):
|
32 |
+
return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int)
|
33 |
+
|
34 |
+
|
35 |
+
def create_output_directories(output_dir):
|
36 |
+
try:
|
37 |
+
os.makedirs(osp.join(output_dir, "audio"), exist_ok=True)
|
38 |
+
os.makedirs(osp.join(output_dir, "video"), exist_ok=True)
|
39 |
+
except OSError as e:
|
40 |
+
print(f"Error creating directories: {e}")
|
41 |
+
raise
|
42 |
+
|
43 |
+
|
44 |
+
def frame_sample(duration, mode="uniform", num_frames=None, fps=None):
|
45 |
+
if mode == "uniform":
|
46 |
+
assert num_frames is not None, "Number of frames must be provided for uniform sampling."
|
47 |
+
# NOTE: v1 version
|
48 |
+
# Calculate the size of each segment from which a frame will be extracted
|
49 |
+
seg_size = float(duration - 1) / num_frames
|
50 |
+
|
51 |
+
frame_ids = []
|
52 |
+
for i in range(num_frames):
|
53 |
+
# Calculate the start and end indices of each segment
|
54 |
+
start = seg_size * i
|
55 |
+
end = seg_size * (i + 1)
|
56 |
+
# Append the middle index of the segment to the list
|
57 |
+
frame_ids.append((start + end) / 2)
|
58 |
+
|
59 |
+
return np.round(np.array(frame_ids) + 1e-6).astype(int)
|
60 |
+
# NOTE: v0 version
|
61 |
+
# return np.linspace(0, duration-1, num_frames, dtype=int)
|
62 |
+
elif mode == "fps":
|
63 |
+
assert fps is not None, "FPS must be provided for FPS sampling."
|
64 |
+
segment_len = min(fps // NUM_FRAMES_PER_SECOND, duration)
|
65 |
+
return np.arange(segment_len // 2, duration, segment_len, dtype=int)
|
66 |
+
else:
|
67 |
+
raise ImportError(f"Unsupported frame sampling mode: {mode}")
|
68 |
+
|
69 |
+
|
70 |
+
def expand2square(pil_img, background_color):
|
71 |
+
width, height = pil_img.size
|
72 |
+
if width == height:
|
73 |
+
return pil_img
|
74 |
+
elif width > height:
|
75 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
76 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
77 |
+
return result
|
78 |
+
else:
|
79 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
80 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
81 |
+
return result
|
82 |
+
|
83 |
+
|
84 |
+
def process_video(video_path, processor, s=None, e=None, aspect_ratio="pad", num_frames=NUM_FRAMES):
|
85 |
+
if isinstance(video_path, str):
|
86 |
+
if s is not None and e is not None:
|
87 |
+
s = s if s >= 0.0 else 0.0
|
88 |
+
e = e if e >= 0.0 else 0.0
|
89 |
+
if s > e:
|
90 |
+
s, e = e, s
|
91 |
+
elif s == e:
|
92 |
+
e = s + 1
|
93 |
+
|
94 |
+
# 1. Loading Video
|
95 |
+
if os.path.isdir(video_path):
|
96 |
+
frame_files = sorted(os.listdir(video_path))
|
97 |
+
|
98 |
+
fps = 3
|
99 |
+
num_frames_of_video = len(frame_files)
|
100 |
+
elif video_path.endswith(".gif"):
|
101 |
+
gif_reader = imageio.get_reader(video_path)
|
102 |
+
|
103 |
+
fps = 25
|
104 |
+
num_frames_of_video = len(gif_reader)
|
105 |
+
else:
|
106 |
+
try:
|
107 |
+
vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
108 |
+
except: # noqa: E722
|
109 |
+
return None
|
110 |
+
|
111 |
+
fps = vreader.get_avg_fps()
|
112 |
+
num_frames_of_video = len(vreader)
|
113 |
+
|
114 |
+
# 2. Determine frame range & Calculate frame indices
|
115 |
+
f_start = 0 if s is None else max(int(s * fps) - 1, 0)
|
116 |
+
f_end = num_frames_of_video - 1 if e is None else min(int(e * fps) - 1, num_frames_of_video - 1)
|
117 |
+
frame_indices = list(range(f_start, f_end + 1))
|
118 |
+
|
119 |
+
duration = len(frame_indices)
|
120 |
+
# 3. Sampling frame indices
|
121 |
+
if num_frames is None:
|
122 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode="fps", fps=fps)]
|
123 |
+
else:
|
124 |
+
sampled_frame_indices = [
|
125 |
+
frame_indices[i] for i in frame_sample(duration, mode="uniform", num_frames=num_frames)
|
126 |
+
]
|
127 |
+
|
128 |
+
# 4. Acquire frame data
|
129 |
+
if os.path.isdir(video_path):
|
130 |
+
video_data = [Image.open(os.path.join(video_path, frame_files[f_idx])) for f_idx in sampled_frame_indices]
|
131 |
+
elif video_path.endswith(".gif"):
|
132 |
+
video_data = [
|
133 |
+
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB))
|
134 |
+
for idx, frame in enumerate(gif_reader)
|
135 |
+
if idx in sampled_frame_indices
|
136 |
+
]
|
137 |
+
else:
|
138 |
+
video_data = [Image.fromarray(frame) for frame in vreader.get_batch(sampled_frame_indices).asnumpy()]
|
139 |
+
|
140 |
+
elif isinstance(video_path, np.ndarray):
|
141 |
+
video_data = [Image.fromarray(f) for f in video_path]
|
142 |
+
elif isinstance(video_path, list) and isinstance(video_path[0], np.ndarray):
|
143 |
+
video_data = [Image.fromarray(f) for f in video_path]
|
144 |
+
elif isinstance(video_path, list) and isinstance(video_path[0], str):
|
145 |
+
video_data = [Image.open(f) for f in video_path]
|
146 |
+
elif isinstance(video_path, list) and isinstance(video_path[0], Image.Image):
|
147 |
+
video_data = video_path
|
148 |
+
else:
|
149 |
+
raise ValueError(f"Unsupported video path type: {type(video_path)}")
|
150 |
+
|
151 |
+
while num_frames is not None and len(video_data) < num_frames:
|
152 |
+
video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
|
153 |
+
|
154 |
+
# MAX_FRAMES filter
|
155 |
+
if MAX_FRAMES:
|
156 |
+
video_data = video_data[:MAX_FRAMES]
|
157 |
+
|
158 |
+
if aspect_ratio == "pad":
|
159 |
+
images = [expand2square(f, tuple(int(x * 255) for x in processor.image_mean)) for f in video_data]
|
160 |
+
else:
|
161 |
+
images = list(video_data)
|
162 |
+
video = processor.preprocess(images, return_tensors="pt")["pixel_values"]
|
163 |
+
return video
|
164 |
+
|
165 |
+
|
166 |
+
class VideoFeatureExtractor:
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = "openai/clip-vit-large-patch14",
|
170 |
+
device: str = "cuda",
|
171 |
+
):
|
172 |
+
self.device = device
|
173 |
+
|
174 |
+
self.processor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path)
|
175 |
+
self.model = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path).to(self.device).half()
|
176 |
+
|
177 |
+
def extract_features(self, video_path):
|
178 |
+
images = process_video(video_path, self.processor)
|
179 |
+
if images is None:
|
180 |
+
return None
|
181 |
+
clip_feature = self.model(images.to(self.device).half()).image_embeds
|
182 |
+
|
183 |
+
return clip_feature
|
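A hedged usage example for VideoFeatureExtractor; the video path is illustrative:

    extractor = VideoFeatureExtractor(device="cuda")
    feats = extractor.extract_features("example_clip.mp4")  # hypothetical file
    if feats is not None:
        print(feats.shape)  # (num_sampled_frames, 768) CLIP ViT-L/14 image embeddings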
184 |
+
|
185 |
+
|
186 |
+
def video_processor(item, feature_extractor, output_dir=None):
|
187 |
+
video_path = Path(item)
|
188 |
+
if not os.path.exists(video_path):
|
189 |
+
return
|
190 |
+
|
191 |
+
clip_feature = feature_extractor.extract_features(str(video_path))
|
192 |
+
if clip_feature is None:
|
193 |
+
return
|
194 |
+
|
195 |
+
if output_dir is not None and not os.path.exists(output_dir):
|
196 |
+
os.makedirs(output_dir, exist_ok=True)
|
197 |
+
output_path = osp.join(output_dir, f"{video_path.stem}.pt")
|
198 |
+
else:
|
199 |
+
output_path = video_path.with_suffix(".clip")
|
200 |
+
|
201 |
+
torch.save(clip_feature, output_path)
|
202 |
+
|
203 |
+
|
204 |
+
def s_thread(items, id, device, output_dir):
|
205 |
+
feature_extractor = VideoFeatureExtractor(device=device)
|
206 |
+
for i, data in tqdm(enumerate(items), total=len(items), position=id):
|
207 |
+
video_processor(data, feature_extractor, output_dir)
|
208 |
+
|
209 |
+
|
210 |
+
def load_tensor(file_path, map_location="cpu", weights_only=True):
|
211 |
+
try:
|
212 |
+
return torch.load(file_path, map_location=map_location, weights_only=weights_only)
|
213 |
+
except FileNotFoundError:
|
214 |
+
logging.error(f"File not found: {file_path}")
|
215 |
+
except torch.serialization.pickle.UnpicklingError:
|
216 |
+
logging.error(f"Failed to unpickle file: {file_path}")
|
217 |
+
except Exception as e:
|
218 |
+
logging.error(f"An error occurred while loading {file_path}: {e}")
|
219 |
+
return None
|
220 |
+
|
221 |
+
|
222 |
+
def post_check(directory):
|
223 |
+
if not osp.isdir(directory):
|
224 |
+
logging.error(f"Invalid directory: {directory}")
|
225 |
+
return
|
226 |
+
|
227 |
+
video_dir = osp.join(directory, "video")
|
228 |
+
pt_files = glob.glob(f"{video_dir}/*.pt")
|
229 |
+
|
230 |
+
for file_path in tqdm(pt_files):
|
231 |
+
embeds = load_tensor(file_path)
|
232 |
+
if embeds is None:
|
233 |
+
continue
|
234 |
+
|
235 |
+
audio_file_path = file_path.replace("video", "audio")
|
236 |
+
audio_text_embeds = load_tensor(audio_file_path)
|
237 |
+
if audio_text_embeds is None:
|
238 |
+
logging.error(f"Failed to load audio file: {audio_file_path}")
|
239 |
+
continue
|
240 |
+
|
241 |
+
text = audio_text_embeds.get("text")
|
242 |
+
mel = audio_text_embeds.get("mel")
|
243 |
+
if text is None or mel is None:
|
244 |
+
logging.error(f"Missing 'text' or 'mel' in {audio_file_path}")
|
245 |
+
|
246 |
+
|
247 |
+
def args_parse():
|
248 |
+
args = argparse.ArgumentParser()
|
249 |
+
args.add_argument("--data_type", "-d", type=str, default="video", help="'audio' or 'video'")
|
250 |
+
args.add_argument("--check", action="store_true", help="post check, if any pt file was damaged")
|
251 |
+
|
252 |
+
args.add_argument(
|
253 |
+
"--num_threads",
|
254 |
+
"-n",
|
255 |
+
type=int,
|
256 |
+
default=1,
|
257 |
+
required=False,
|
258 |
+
help="num_threads",
|
259 |
+
)
|
260 |
+
args.add_argument(
|
261 |
+
"--input",
|
262 |
+
"-i",
|
263 |
+
type=str,
|
264 |
+
required=True,
|
265 |
+
help="input file path",
|
266 |
+
)
|
267 |
+
args.add_argument(
|
268 |
+
"--output_dir",
|
269 |
+
"-o",
|
270 |
+
type=str,
|
271 |
+
default=None,
|
272 |
+
help="output folder path",
|
273 |
+
)
|
274 |
+
args.add_argument("--multi_gpu", "-m", nargs="+", type=str, default=None, required=False, help="GPU ids")
|
275 |
+
|
276 |
+
args = args.parse_args()
|
277 |
+
return args
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
args_main = args_parse()
|
282 |
+
|
283 |
+
if args_main.check:
|
284 |
+
post_check(args_main.output_dir)
|
285 |
+
exit(0)
|
286 |
+
|
287 |
+
gpu_ids = ["cuda:0"]
|
288 |
+
if args_main.multi_gpu is not None:
|
289 |
+
gpu_ids = [f"cuda:{gpu}" for gpu in args_main.multi_gpu]
|
290 |
+
|
291 |
+
output_dir = args_main.output_dir
|
292 |
+
if output_dir is not None:
|
293 |
+
create_output_directories(output_dir)
|
294 |
+
|
295 |
+
rows = None
|
296 |
+
rows = [it.strip() for it in Path(args_main.input).read_text().split("\n") if it.strip() != ""]
|
297 |
+
|
298 |
+
chunks = np.array_split(rows, args_main.num_threads)
|
299 |
+
chunks = [chunk.tolist() for chunk in chunks]
|
300 |
+
|
301 |
+
processes = []
|
302 |
+
mp.set_start_method("spawn", force=True)
|
303 |
+
for idx, chunk in enumerate(chunks):
|
304 |
+
device = gpu_ids[idx % len(gpu_ids)]
|
305 |
+
p = mp.Process(target=s_thread, args=(chunk, idx, device, output_dir))
|
306 |
+
processes.append(p)
|
307 |
+
p.start()
|
308 |
+
|
309 |
+
for process in processes:
|
310 |
+
process.join()
|
311 |
+
|
312 |
+
# DEBUG
|
313 |
+
# s_thread(args_main, input_dir, output_dir, chunks[0], 0, "cuda:0")
|
314 |
+
|
315 |
+
print("process done!")
|
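Based on the argparse options above, a typical batch extraction run might look like this (the list file and output folder are placeholders); each line of the input list is one video path, and a feature tensor is saved per video:

    python src/moviedubber/infer/video_preprocess.py -i videos.lst -o data/clipfeat -n 4 -m 0 1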
src/moviedubber/infer_with_mmlm_result.py
ADDED
@@ -0,0 +1,339 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import onnxruntime
|
10 |
+
import soundfile
|
11 |
+
import tomli
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torchaudio
|
15 |
+
import torchaudio.compliance.kaldi as kaldi
|
16 |
+
from moviepy import AudioFileClip, VideoFileClip
|
17 |
+
from omegaconf import OmegaConf
|
18 |
+
from pydub import AudioSegment
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
|
22 |
+
src_path = Path(osp.dirname(__file__)).parent.parent
|
23 |
+
sys.path.insert(0, str(src_path))
|
24 |
+
sys.path.append(str(src_path / "src/third_party/BigVGAN"))
|
25 |
+
|
26 |
+
from src.moviedubber.infer.utils_infer import (
|
27 |
+
cfg_strength,
|
28 |
+
chunk_text,
|
29 |
+
load_model,
|
30 |
+
load_vocoder,
|
31 |
+
mel_spec_type,
|
32 |
+
nfe_step,
|
33 |
+
sway_sampling_coef,
|
34 |
+
)
|
35 |
+
from src.moviedubber.infer.video_preprocess import VideoFeatureExtractor
|
36 |
+
from src.moviedubber.model import ControlNetDiT, DiT
|
37 |
+
from src.moviedubber.model.utils import convert_char_to_pinyin
|
38 |
+
|
39 |
+
|
40 |
+
def concat_movie_with_audio(wav, video_path, out_dir):
|
41 |
+
if not os.path.exists(wav):
|
42 |
+
raise FileNotFoundError(f"Audio file {wav} does not exist")
|
43 |
+
|
44 |
+
if not os.path.exists(video_path):
|
45 |
+
raise FileNotFoundError(f"Video file {video_path} does not exist")
|
46 |
+
|
47 |
+
try:
|
48 |
+
with (
|
49 |
+
AudioFileClip(str(wav)) as audio_clip,
|
50 |
+
VideoFileClip(str(video_path)) as video_clip,
|
51 |
+
):
|
52 |
+
duration = min(video_clip.duration, audio_clip.duration)
|
53 |
+
|
54 |
+
video_subclip = video_clip.subclipped(0, duration)
|
55 |
+
audio_subclip = audio_clip.subclipped(0, duration)
|
56 |
+
|
57 |
+
final_video = video_subclip.with_audio(audio_subclip)
|
58 |
+
|
59 |
+
output_path = wav.replace(".wav", ".mp4")
|
60 |
+
|
61 |
+
final_video.write_videofile(
|
62 |
+
str(output_path),
|
63 |
+
codec="libx264",
|
64 |
+
audio_codec="mp3",
|
65 |
+
fps=25,
|
66 |
+
logger=None,
|
67 |
+
threads=1,
|
68 |
+
temp_audiofile_path=out_dir,
|
69 |
+
)
|
70 |
+
|
71 |
+
except Exception as e:
|
72 |
+
print(f"Error processing {wav} {video_path}: {str(e)}")
|
73 |
+
|
74 |
+
return output_path
|
75 |
+
|
76 |
+
|
77 |
+
def get_spk_emb(audio_path, ort_session):
|
78 |
+
audio, sample_rate = torchaudio.load(str(audio_path))
|
79 |
+
if sample_rate != 16000:
|
80 |
+
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
|
81 |
+
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
|
82 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
83 |
+
embedding = (
|
84 |
+
ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0]
|
85 |
+
.flatten()
|
86 |
+
.tolist()
|
87 |
+
)
|
88 |
+
return embedding
|
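get_spk_emb expects an already-constructed ONNX Runtime session for the CAMPPlus speaker encoder; a minimal sketch, with placeholder paths:

    sess = onnxruntime.InferenceSession("ckpts/campplus.onnx", providers=["CPUExecutionProvider"])  # hypothetical path
    emb = get_spk_emb("speaker_ref.wav", sess)  # hypothetical wav; returns a 192-dim embedding as a list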
89 |
+
|
90 |
+
|
91 |
+
def load_models(config, device):
|
92 |
+
model_cfg = config.get("model_cfg", "src/moviedubber/configs/basemodel.yaml")
|
93 |
+
ckpt_file = config.get("ckpt_file", None)
|
94 |
+
campplus_path = config.get("campplus_path", None)
|
95 |
+
vocab_file = config.get("vocab_file", None)
|
96 |
+
|
97 |
+
vocoder_local_path = config.get("vocoder_local_path", None)
|
98 |
+
|
99 |
+
if ckpt_file is None or vocab_file is None or vocoder_local_path is None or campplus_path is None:
|
100 |
+
raise ValueError("ckpt_file, vocab_file and vocoder_local_path must be specified")
|
101 |
+
|
102 |
+
vocoder_name = config.get("vocoder_name", mel_spec_type)
|
103 |
+
|
104 |
+
vocoder = load_vocoder(local_path=vocoder_local_path, device=device)
|
105 |
+
|
106 |
+
model_cls = DiT
|
107 |
+
model_cfg = OmegaConf.load(model_cfg).model.arch
|
108 |
+
controlnet = ControlNetDiT
|
109 |
+
|
110 |
+
ema_model = load_model(
|
111 |
+
model_cls,
|
112 |
+
model_cfg,
|
113 |
+
ckpt_file,
|
114 |
+
mel_spec_type=vocoder_name,
|
115 |
+
vocab_file=vocab_file,
|
116 |
+
controlnet=controlnet,
|
117 |
+
device=device,
|
118 |
+
)
|
119 |
+
|
120 |
+
option = onnxruntime.SessionOptions()
|
121 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
122 |
+
option.intra_op_num_threads = 1
|
123 |
+
providers = ["CPUExecutionProvider"]
|
124 |
+
ort_session = onnxruntime.InferenceSession(
|
125 |
+
campplus_path,
|
126 |
+
sess_options=option,
|
127 |
+
providers=providers,
|
128 |
+
)
|
129 |
+
return ema_model, vocoder, ort_session
|
130 |
+
|
131 |
+
|
132 |
+
def main(config, device, chunk, gen_dir, target_dir, out_dir, idx):
|
133 |
+
ema_model, vocoder, ort_session = load_models(config, device=device)
|
134 |
+
|
135 |
+
videofeature_extractor = VideoFeatureExtractor(device=device)
|
136 |
+
|
137 |
+
for it in tqdm(chunk, total=len(chunk), position=idx, desc=f"Processing {idx}"):
|
138 |
+
wav, video, text, ref_wav = it
|
139 |
+
|
140 |
+
with open(f"{target_dir}/{wav.split('/')[-1].split('.')[0]}.txt", "a") as f:
|
141 |
+
f.write(text + "\n")
|
142 |
+
|
143 |
+
if wav.endswith(".mp3"):
|
144 |
+
audio = AudioSegment.from_mp3(wav)
|
145 |
+
|
146 |
+
wav_file = wav.replace(".mp3", ".wav")
|
147 |
+
audio.export(wav_file, format="wav")
|
148 |
+
|
149 |
+
wav = Path(wav).with_suffix(".wav")
|
150 |
+
if wav.exists() is False:
|
151 |
+
continue
|
152 |
+
|
153 |
+
os.system(f"cp {wav} {target_dir}/")
|
154 |
+
|
155 |
+
gen_audio, sr = torchaudio.load(str(wav))
|
156 |
+
resampler = torchaudio.transforms.Resample(sr, 24000)
|
157 |
+
if sr != 24000:
|
158 |
+
gen_audio = resampler(gen_audio)
|
159 |
+
|
160 |
+
if gen_audio.shape[0] > 1:
|
161 |
+
gen_audio = torch.mean(gen_audio, dim=0, keepdim=True)
|
162 |
+
|
163 |
+
gen_video = video
|
164 |
+
gen_clip_path = gen_video.replace(".mp4", ".clip")
|
165 |
+
|
166 |
+
if not os.path.exists(gen_clip_path):
|
167 |
+
gen_clip = videofeature_extractor.extract_features(gen_video)
|
168 |
+
|
169 |
+
torch.save(gen_clip.detach().cpu(), gen_clip_path)
|
170 |
+
|
171 |
+
else:
|
172 |
+
gen_clip = torch.load(gen_clip_path, weights_only=True).to(device=device, dtype=torch.float32)
|
173 |
+
|
174 |
+
if ref_wav == "None":
|
175 |
+
use_ref_audio = False
|
176 |
+
gen_text_ = text
|
177 |
+
|
178 |
+
gen_clip_ = gen_clip
|
179 |
+
|
180 |
+
ref_audio_ = gen_audio
|
181 |
+
|
182 |
+
spk_emb = torch.zeros(1, 1, 192).to(device=device, dtype=torch.float32)
|
183 |
+
|
184 |
+
else:
|
185 |
+
use_ref_audio = True
|
186 |
+
ref_audio = Path(ref_wav)
|
187 |
+
|
188 |
+
spk_emb = get_spk_emb(ref_audio, ort_session)
|
189 |
+
spk_emb = torch.tensor(spk_emb).to(device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
190 |
+
|
191 |
+
ref_text = ref_audio.with_suffix(".txt").read_text().strip()
|
192 |
+
gen_text_ = ref_text + " " + text
|
193 |
+
|
194 |
+
if ref_audio.exists() is False:
|
195 |
+
raise Exception(f"ref_audio {ref_audio} not found")
|
196 |
+
|
197 |
+
if ref_audio.suffix == ".mp3":
|
198 |
+
audio = AudioSegment.from_mp3(ref_audio)
|
199 |
+
|
200 |
+
wav_file = ref_audio.with_suffix(".wav")
|
201 |
+
audio.export(wav_file, format="wav")
|
202 |
+
|
203 |
+
ref_audio_, ref_sr = torchaudio.load(str(ref_audio.with_suffix(".wav")))
|
204 |
+
resampler = torchaudio.transforms.Resample(ref_sr, 24000)  # use the reference clip's own sample rate
|
205 |
+
if ref_sr != 24000:
|
206 |
+
ref_audio_ = resampler(ref_audio_)
|
207 |
+
|
208 |
+
if ref_audio_.shape[0] > 1:
|
209 |
+
ref_audio_ = torch.mean(ref_audio_, dim=0, keepdim=True)
|
210 |
+
|
211 |
+
ref_video = ref_audio.with_suffix(".mp4")
|
212 |
+
ref_clip_path = ref_video.with_suffix(".clip")
|
213 |
+
|
214 |
+
if not ref_clip_path.exists():
|
215 |
+
ref_clip = videofeature_extractor.extract_features(str(ref_video))
|
216 |
+
|
217 |
+
torch.save(ref_clip.detach().cpu(), ref_clip_path)
|
218 |
+
|
219 |
+
else:
|
220 |
+
ref_clip = torch.load(ref_clip_path, weights_only=True).to(device=device, dtype=torch.float32)
|
221 |
+
|
222 |
+
gen_clip_ = torch.cat([ref_clip, gen_clip], dim=0)
|
223 |
+
|
224 |
+
gen_audio_len = gen_audio.shape[1] // 256
|
225 |
+
|
226 |
+
if use_ref_audio:
|
227 |
+
ref_audio_len = ref_audio_.shape[1] // 256
|
228 |
+
duration = ref_audio_len + gen_audio_len
|
229 |
+
else:
|
230 |
+
duration = gen_audio_len
|
231 |
+
|
232 |
+
gen_clip_ = gen_clip_.unsqueeze(0).to(device=device, dtype=torch.float32).transpose(1, 2)
|
233 |
+
gen_clip_ = F.interpolate(gen_clip_, size=duration, mode="linear", align_corners=False).transpose(1, 2)
|
234 |
+
|
235 |
+
gen_text_batches = chunk_text(gen_text_, max_chars=1024)
|
236 |
+
final_text_list = convert_char_to_pinyin(gen_text_batches)
|
237 |
+
|
238 |
+
with torch.inference_mode():
|
239 |
+
generated, _ = ema_model.sample(
|
240 |
+
cond=ref_audio_.to(device),
|
241 |
+
text=final_text_list,
|
242 |
+
clip=gen_clip_,
|
243 |
+
spk_emb=spk_emb,
|
244 |
+
duration=duration,
|
245 |
+
steps=nfe_step,
|
246 |
+
cfg_strength=cfg_strength,
|
247 |
+
sway_sampling_coef=sway_sampling_coef,
|
248 |
+
no_ref_audio=not use_ref_audio,
|
249 |
+
)
|
250 |
+
|
251 |
+
generated = generated.to(torch.float32)
|
252 |
+
|
253 |
+
if use_ref_audio:
|
254 |
+
generated = generated[:, ref_audio_len:, :]
|
255 |
+
|
256 |
+
generated_mel_spec = generated.permute(0, 2, 1)
|
257 |
+
generated_wave = vocoder(generated_mel_spec)
|
258 |
+
|
259 |
+
generated_wave = generated_wave.squeeze().cpu().numpy()
|
260 |
+
|
261 |
+
out_path = osp.join(gen_dir, f"{wav.stem}.wav")
|
262 |
+
soundfile.write(out_path, generated_wave, samplerate=24000)
|
263 |
+
_ = concat_movie_with_audio(out_path, gen_video, out_dir)
|
264 |
+
|
265 |
+
|
266 |
+
if __name__ == "__main__":
|
267 |
+
import torch.multiprocessing as mp
|
268 |
+
|
269 |
+
parser = argparse.ArgumentParser(
|
270 |
+
prog="python3 infer-cli.py",
|
271 |
+
description="Commandline interface for moviedubber infer with Advanced Batch Processing.",
|
272 |
+
epilog="Specify options above to override one or more settings from config.",
|
273 |
+
)
|
274 |
+
parser.add_argument(
|
275 |
+
"-c",
|
276 |
+
"--config",
|
277 |
+
type=str,
|
278 |
+
default="src/moviedubber/infer/basic.toml",
|
279 |
+
help="The configuration file, default see infer/basic.toml",
|
280 |
+
)
|
281 |
+
parser.add_argument("-i", "--input_list", type=str, required=True, help="The val list file")
|
282 |
+
parser.add_argument("-s", "--ref_spk_list", type=str, required=True, help="The spk list file")
|
283 |
+
parser.add_argument("-o", "--out_dir", type=str, default="data/dubberout", help="The output directory")
|
284 |
+
parser.add_argument("--gpuids", type=str, help="GPU ids to use, split by comma")
|
285 |
+
parser.add_argument("--nums_workers", type=int, default=1, help="Number of workers for per gpu")
|
286 |
+
|
287 |
+
args = parser.parse_args()
|
288 |
+
|
289 |
+
out_dir = args.out_dir
|
290 |
+
input_list = args.input_list
|
291 |
+
gpu_ids = args.gpuids.split(",") if args.gpuids else ["0"]
|
292 |
+
num_pre = args.nums_workers
|
293 |
+
spk_ref_path = args.ref_spk_list
|
294 |
+
|
295 |
+
config = tomli.load(open(args.config, "rb"))
|
296 |
+
|
297 |
+
gen_lst = Path(input_list).read_text().splitlines()[1:]
|
298 |
+
|
299 |
+
gen_pre_conf = []
|
300 |
+
|
301 |
+
spk_lines = Path(spk_ref_path).read_text().splitlines()
|
302 |
+
|
303 |
+
for idx, line in enumerate(gen_lst):
|
304 |
+
if line.strip():
|
305 |
+
mp4_path, is_correc, _, _ = line.split(",")
|
306 |
+
|
307 |
+
wav_path = mp4_path.replace(".mp4", ".mp3")
|
308 |
+
text = Path(wav_path.replace(".mp3", ".txt")).read_text().strip()
|
309 |
+
|
310 |
+
if is_correc == "True":
|
311 |
+
ref_wav = spk_lines[idx].split(",")[1].strip()
|
312 |
+
else:
|
313 |
+
ref_wav = random.choice(spk_lines).split(",")[-1].strip() # Use random speaker for incorrect samples
|
314 |
+
|
315 |
+
gen_pre_conf.append([wav_path, mp4_path, text, ref_wav])
|
316 |
+
|
317 |
+
chunks = np.array_split(gen_pre_conf, len(gpu_ids) * num_pre)
|
318 |
+
|
319 |
+
gen_dir = os.path.join(out_dir, "generated")
|
320 |
+
target_dir = os.path.join(out_dir, "target")
|
321 |
+
|
322 |
+
if os.path.exists(gen_dir) is False or os.path.exists(target_dir) is False:
|
323 |
+
os.makedirs(gen_dir, exist_ok=True)
|
324 |
+
os.makedirs(target_dir, exist_ok=True)
|
325 |
+
|
326 |
+
mp.set_start_method("spawn", force=True)
|
327 |
+
processes = []
|
328 |
+
for idx, chunk in enumerate(chunks):
|
329 |
+
device = gpu_ids[idx % len(gpu_ids)]
|
330 |
+
|
331 |
+
device = f"cuda:{device}"
|
332 |
+
p = mp.Process(target=main, args=(config, device, chunk, gen_dir, target_dir, out_dir, idx))
|
333 |
+
processes.append(p)
|
334 |
+
p.start()
|
335 |
+
|
336 |
+
for process in processes:
|
337 |
+
process.join()
|
338 |
+
|
339 |
+
print("All processes finished.")
|
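A representative launch command, following the arguments defined above; the MMLM result list and speaker reference list file names are placeholders:

    python src/moviedubber/infer_with_mmlm_result.py \
        -c src/moviedubber/infer/basic.toml \
        -i data/mmlm_result.lst \
        -s data/spk_ref.lst \
        -o data/dubberout --gpuids 0,1 --nums_workers 2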
src/moviedubber/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
from .cfm import CFM
|
2 |
+
from .dit import ControlNetDiT, DiT
|
3 |
+
|
4 |
+
|
5 |
+
__all__ = ["CFM", "UNetT", "DiT", "ControlNetDiT", "MMDiT", "Trainer"]
|
src/moviedubber/model/cfm.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
# modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/cfm.py
|
2 |
+
|
3 |
+
"""
|
4 |
+
ein notation:
|
5 |
+
b - batch
|
6 |
+
n - sequence
|
7 |
+
nt - text sequence
|
8 |
+
nw - raw wave length
|
9 |
+
d - dimension
|
10 |
+
"""
|
11 |
+
|
12 |
+
from __future__ import annotations
|
13 |
+
|
14 |
+
from typing import Callable
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn.utils.rnn import pad_sequence
|
20 |
+
from torchdiffeq import odeint
|
21 |
+
|
22 |
+
from .modules import MelSpec
|
23 |
+
from .utils import (
|
24 |
+
default,
|
25 |
+
exists,
|
26 |
+
lens_to_mask,
|
27 |
+
list_str_to_idx,
|
28 |
+
list_str_to_tensor,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
class CFM(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
transformer: nn.Module,
|
36 |
+
sigma=0.0,
|
37 |
+
odeint_kwargs: dict = dict(
|
38 |
+
method="euler" # 'midpoint'
|
39 |
+
),
|
40 |
+
num_channels=None,
|
41 |
+
mel_spec_module: nn.Module | None = None,
|
42 |
+
mel_spec_kwargs: dict = dict(),
|
43 |
+
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
|
44 |
+
vocab_char_map: dict[str, int] | None = None,
|
45 |
+
controlnet: nn.Module | None = None,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
self.frac_lengths_mask = frac_lengths_mask
|
50 |
+
|
51 |
+
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
|
52 |
+
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
|
53 |
+
self.num_channels = num_channels
|
54 |
+
|
55 |
+
self.transformer = transformer
|
56 |
+
dim = transformer.dim
|
57 |
+
self.dim = dim
|
58 |
+
|
59 |
+
self.sigma = sigma
|
60 |
+
|
61 |
+
self.odeint_kwargs = odeint_kwargs
|
62 |
+
|
63 |
+
self.vocab_char_map = vocab_char_map
|
64 |
+
|
65 |
+
self.controlnet = controlnet
|
66 |
+
|
67 |
+
@property
|
68 |
+
def device(self):
|
69 |
+
return next(self.parameters()).device
|
70 |
+
|
71 |
+
@torch.no_grad()
|
72 |
+
def sample(
|
73 |
+
self,
|
74 |
+
cond: float["b n d"] | float["b nw"], # noqa: F722
|
75 |
+
text: int["b nt"] | list[str], # noqa: F722
|
76 |
+
clip: float["b n d"], # noqa: F722
|
77 |
+
duration: int | int["b"], # noqa: F821
|
78 |
+
*,
|
79 |
+
caption_emb: float["b n d"] | None = None, # noqa: F722
|
80 |
+
spk_emb: float["b n d"] | None = None, # noqa: F722
|
81 |
+
lens: int["b"] | None = None, # noqa: F821
|
82 |
+
steps=32,
|
83 |
+
cfg_strength=1.0,
|
84 |
+
sway_sampling_coef=None,
|
85 |
+
seed: int | None = None,
|
86 |
+
max_duration=4096,
|
87 |
+
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
88 |
+
no_ref_audio=False,
|
89 |
+
duplicate_test=False,
|
90 |
+
t_inter=0.1,
|
91 |
+
edit_mask=None,
|
92 |
+
):
|
93 |
+
self.eval()
|
94 |
+
|
95 |
+
if cond.ndim == 2:
|
96 |
+
cond = self.mel_spec(cond)
|
97 |
+
cond = cond.permute(0, 2, 1)
|
98 |
+
assert cond.shape[-1] == self.num_channels
|
99 |
+
|
100 |
+
cond = cond.to(next(self.parameters()).dtype)
|
101 |
+
|
102 |
+
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
103 |
+
if not exists(lens):
|
104 |
+
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
|
105 |
+
|
106 |
+
if isinstance(text, list):
|
107 |
+
if exists(self.vocab_char_map):
|
108 |
+
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
109 |
+
else:
|
110 |
+
text = list_str_to_tensor(text).to(device)
|
111 |
+
assert text.shape[0] == batch
|
112 |
+
|
113 |
+
if exists(text):
|
114 |
+
text_lens = (text != -1).sum(dim=-1)
|
115 |
+
lens = torch.maximum(text_lens, lens)
|
116 |
+
|
117 |
+
cond_mask = lens_to_mask(lens)
|
118 |
+
if edit_mask is not None:
|
119 |
+
cond_mask = cond_mask & edit_mask
|
120 |
+
|
121 |
+
if isinstance(duration, int):
|
122 |
+
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
123 |
+
|
124 |
+
# duration = torch.maximum(lens + 1, duration)
|
125 |
+
|
126 |
+
duration = duration.clamp(max=max_duration)
|
127 |
+
max_duration = duration.amax()
|
128 |
+
|
129 |
+
if duplicate_test:
|
130 |
+
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
|
131 |
+
|
132 |
+
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
133 |
+
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
|
134 |
+
cond_mask = cond_mask.unsqueeze(-1)
|
135 |
+
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
136 |
+
|
137 |
+
if batch > 1:
|
138 |
+
mask = lens_to_mask(duration)
|
139 |
+
else:
|
140 |
+
mask = None
|
141 |
+
|
142 |
+
if no_ref_audio:
|
143 |
+
cond = torch.zeros_like(cond)
|
144 |
+
|
145 |
+
def fn(t, x):
|
146 |
+
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
147 |
+
|
148 |
+
controlnet_embeds = self.controlnet(
|
149 |
+
x=x,
|
150 |
+
text=text,
|
151 |
+
clip=clip,
|
152 |
+
spk_emb=spk_emb,
|
153 |
+
caption=caption_emb,
|
154 |
+
time=t,
|
155 |
+
)
|
156 |
+
|
157 |
+
cond_pred = self.transformer(
|
158 |
+
x=x,
|
159 |
+
cond=step_cond,
|
160 |
+
text=text,
|
161 |
+
time=t,
|
162 |
+
mask=mask,
|
163 |
+
drop_audio_cond=[False],
|
164 |
+
drop_text=[False],
|
165 |
+
controlnet_embeds=controlnet_embeds,
|
166 |
+
)
|
167 |
+
|
168 |
+
null_pred = self.transformer(
|
169 |
+
x=x,
|
170 |
+
cond=step_cond,
|
171 |
+
text=text,
|
172 |
+
time=t,
|
173 |
+
mask=mask,
|
174 |
+
drop_audio_cond=[True],
|
175 |
+
drop_text=[True],
|
176 |
+
controlnet_embeds=None,
|
177 |
+
)
|
178 |
+
|
179 |
+
return null_pred + (cond_pred - null_pred) * 2
|
180 |
+
|
181 |
+
y0 = []
|
182 |
+
for dur in duration:
|
183 |
+
if exists(seed):
|
184 |
+
torch.manual_seed(seed)
|
185 |
+
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
|
186 |
+
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
|
187 |
+
|
188 |
+
t_start = 0
|
189 |
+
|
190 |
+
if duplicate_test:
|
191 |
+
t_start = t_inter
|
192 |
+
y0 = (1 - t_start) * y0 + t_start * test_cond
|
193 |
+
steps = int(steps * (1 - t_start))
|
194 |
+
|
195 |
+
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
196 |
+
if sway_sampling_coef is not None:
|
197 |
+
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
198 |
+
|
199 |
+
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
200 |
+
|
201 |
+
sampled = trajectory[-1]
|
202 |
+
out = sampled
|
203 |
+
out = torch.where(cond_mask, cond, out)
|
204 |
+
|
205 |
+
if exists(vocoder):
|
206 |
+
out = out.permute(0, 2, 1)
|
207 |
+
out = vocoder(out)
|
208 |
+
|
209 |
+
return out, trajectory
|
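The sway_sampling_coef line above warps the uniform ODE time grid; a quick numeric check of the effect with the coefficient used in utils_infer.py (-1.0):

    t = torch.linspace(0, 1, 9)
    t_sway = t + (-1.0) * (torch.cos(torch.pi / 2 * t) - 1 + t)  # equals 1 - cos(pi/2 * t)
    print(t_sway)  # grid points shift toward t = 0, spending more solver steps early in the flow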
src/moviedubber/model/dit.py
ADDED
@@ -0,0 +1,297 @@
1 |
+
# modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/backbones/dit.py
|
2 |
+
|
3 |
+
"""
|
4 |
+
ein notation:
|
5 |
+
b - batch
|
6 |
+
n - sequence
|
7 |
+
nt - text sequence
|
8 |
+
nw - raw wave length
|
9 |
+
d - dimension
|
10 |
+
"""
|
11 |
+
|
12 |
+
from __future__ import annotations
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch import nn
|
17 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
18 |
+
|
19 |
+
from .modules import (
|
20 |
+
AdaLayerNormZero_Final,
|
21 |
+
ConvNeXtV2Block,
|
22 |
+
ConvPositionEmbedding,
|
23 |
+
DiTBlock,
|
24 |
+
TimestepEmbedding,
|
25 |
+
get_pos_embed_indices,
|
26 |
+
precompute_freqs_cis,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
# Text embedding
|
31 |
+
|
32 |
+
|
33 |
+
class TextEmbedding(nn.Module):
|
34 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
35 |
+
super().__init__()
|
36 |
+
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
37 |
+
|
38 |
+
if conv_layers > 0:
|
39 |
+
self.extra_modeling = True
|
40 |
+
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
41 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
42 |
+
self.text_blocks = nn.Sequential(
|
43 |
+
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
44 |
+
)
|
45 |
+
else:
|
46 |
+
self.extra_modeling = False
|
47 |
+
|
48 |
+
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
49 |
+
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
50 |
+
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
51 |
+
batch, text_len = text.shape[0], text.shape[1]
|
52 |
+
text = F.pad(text, (0, seq_len - text_len), value=0)
|
53 |
+
|
54 |
+
for idx, _drop in enumerate(drop_text): # cfg for text
|
55 |
+
if _drop:
|
56 |
+
text[idx] = torch.zeros_like(text[idx])
|
57 |
+
|
58 |
+
text = self.text_embed(text) # b n -> b n d
|
59 |
+
|
60 |
+
# possible extra modeling
|
61 |
+
if self.extra_modeling:
|
62 |
+
# sinus pos emb
|
63 |
+
batch_start = torch.zeros((batch,), dtype=torch.long)
|
64 |
+
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
65 |
+
text_pos_embed = self.freqs_cis[pos_idx]
|
66 |
+
text = text + text_pos_embed
|
67 |
+
|
68 |
+
# convnextv2 blocks
|
69 |
+
text = self.text_blocks(text)
|
70 |
+
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
# noised input audio and context mixing embedding
|
75 |
+
|
76 |
+
|
77 |
+
class InputEmbedding(nn.Module):
|
78 |
+
def __init__(self, mel_dim, text_dim, out_dim):
|
79 |
+
super().__init__()
|
80 |
+
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
81 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
82 |
+
|
83 |
+
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
84 |
+
for idx, _drop in enumerate(drop_audio_cond): # cfg for cond audio
|
85 |
+
if _drop:
|
86 |
+
cond[idx] = torch.zeros_like(cond[idx])
|
87 |
+
|
88 |
+
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
89 |
+
x = self.conv_pos_embed(x) + x
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
class InputEmbeddingO(nn.Module):
|
94 |
+
def __init__(self, mel_dim, text_dim, out_dim):
|
95 |
+
super().__init__()
|
96 |
+
self.proj = nn.Linear(mel_dim + 512 + text_dim + 192 + 32, out_dim)
|
97 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
98 |
+
|
99 |
+
def forward(
|
100 |
+
self,
|
101 |
+
x: float["b n d"], # noqa: F722
|
102 |
+
text_emb: float["b n d"], # noqa: F722
|
103 |
+
video_emb: float["b n d"], # noqa: F722
|
104 |
+
spk_emb: float["b n d"], # noqa: F722
|
105 |
+
caption_emb: float["b n d"], # noqa: F722
|
106 |
+
):
|
107 |
+
x = self.proj(torch.cat((x, text_emb, video_emb, spk_emb, caption_emb), dim=-1))
|
108 |
+
x = self.conv_pos_embed(x) + x
|
109 |
+
return x
|
110 |
+
|
111 |
+
|
112 |
+
# Transformer backbone using DiT blocks
|
113 |
+
|
114 |
+
|
115 |
+
class DiT(nn.Module):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
*,
|
119 |
+
dim,
|
120 |
+
depth=8,
|
121 |
+
heads=8,
|
122 |
+
dim_head=64,
|
123 |
+
dropout=0.1,
|
124 |
+
ff_mult=4,
|
125 |
+
mel_dim=100,
|
126 |
+
text_num_embeds=256,
|
127 |
+
text_dim=None,
|
128 |
+
conv_layers=0,
|
129 |
+
long_skip_connection=False,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
self.time_embed = TimestepEmbedding(dim)
|
134 |
+
if text_dim is None:
|
135 |
+
text_dim = mel_dim
|
136 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
137 |
+
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
138 |
+
|
139 |
+
self.rotary_embed = RotaryEmbedding(dim_head)
|
140 |
+
|
141 |
+
self.dim = dim
|
142 |
+
self.depth = depth
|
143 |
+
|
144 |
+
self.transformer_blocks = nn.ModuleList(
|
145 |
+
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
146 |
+
)
|
147 |
+
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
148 |
+
|
149 |
+
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
150 |
+
self.proj_out = nn.Linear(dim, mel_dim)
|
151 |
+
|
152 |
+
def forward(
|
153 |
+
self,
|
154 |
+
x: float["b n d"], # nosied input audio # noqa: F722
|
155 |
+
cond: float["b n d"], # masked cond audio # noqa: F722
|
156 |
+
text: int["b nt"], # text # noqa: F722
|
157 |
+
time: float["b"] | float[""], # time step # noqa: F821 F722
|
158 |
+
drop_audio_cond, # cfg for cond audio
|
159 |
+
drop_text, # cfg for text
|
160 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
161 |
+
controlnet_embeds: float["b n d"] | None = None, # noqa: F722
|
162 |
+
):
|
163 |
+
batch, seq_len = x.shape[0], x.shape[1]
|
164 |
+
if time.ndim == 0:
|
165 |
+
time = time.repeat(batch)
|
166 |
+
|
167 |
+
t = self.time_embed(time)
|
168 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
169 |
+
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
170 |
+
|
171 |
+
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
172 |
+
|
173 |
+
if self.long_skip_connection is not None:
|
174 |
+
residual = x
|
175 |
+
|
176 |
+
for i, block in enumerate(self.transformer_blocks):
|
177 |
+
if controlnet_embeds is not None and i < 12:
|
178 |
+
x += controlnet_embeds[i]
|
179 |
+
|
180 |
+
x = block(x, t, mask=mask, rope=rope)
|
181 |
+
|
182 |
+
if self.long_skip_connection is not None:
|
183 |
+
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
184 |
+
|
185 |
+
x = self.norm_out(x, t)
|
186 |
+
output = self.proj_out(x)
|
187 |
+
|
188 |
+
return output
|
189 |
+
|
190 |
+
|
191 |
+
class ControlNetDiT(nn.Module):
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
*,
|
195 |
+
dim,
|
196 |
+
depth=8,
|
197 |
+
heads=8,
|
198 |
+
dim_head=64,
|
199 |
+
dropout=0.1,
|
200 |
+
ff_mult=4,
|
201 |
+
mel_dim=100,
|
202 |
+
text_num_embeds=256,
|
203 |
+
text_dim=None,
|
204 |
+
conv_layers=0,
|
205 |
+
long_skip_connection=False,
|
206 |
+
checkpoint_activations=False,
|
207 |
+
duration_predictor=None,
|
208 |
+
):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
if text_dim is None:
|
212 |
+
text_dim = mel_dim
|
213 |
+
|
214 |
+
self.time_embed = TimestepEmbedding(dim)
|
215 |
+
|
216 |
+
self.rotary_embed = RotaryEmbedding(dim_head)
|
217 |
+
|
218 |
+
self.dim = dim
|
219 |
+
self.depth = depth // 2 + 1
|
220 |
+
|
221 |
+
self.transformer_blocks1 = nn.ModuleList(
|
222 |
+
[
|
223 |
+
DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout)
|
224 |
+
for _ in range(self.depth)
|
225 |
+
]
|
226 |
+
)
|
227 |
+
|
228 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
229 |
+
|
230 |
+
self.input_embed = InputEmbeddingO(mel_dim, text_dim, dim)
|
231 |
+
|
232 |
+
self.spk_embed_affine_layer = torch.nn.Linear(192, 192)
|
233 |
+
self.clip_embed_affine_layer = torch.nn.Linear(768, 512)
|
234 |
+
self.caption_embed_affine_layer = torch.nn.Linear(512, 32)
|
235 |
+
|
236 |
+
self.zero_linear = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(12)])
|
237 |
+
for zero_linear in self.zero_linear:
|
238 |
+
nn.init.zeros_(zero_linear.weight)
|
239 |
+
|
240 |
+
self.duration_predictor = duration_predictor
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
x: float["b n d"], # nosied input audio # noqa: F722
|
245 |
+
text: int["b nt"], # text # noqa: F722
|
246 |
+
clip: float["b n d"], # video clip # noqa: F722
|
247 |
+
spk_emb: float["b d"], # speaker embedding # noqa: F722
|
248 |
+
time: float["b"] | float[""], # time step # noqa: F821 F722
|
249 |
+
caption: float["b nt"] | None = None, # caption # noqa: F722
|
250 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
251 |
+
lens: int["b"] | None = None, # noqa: F722, F821
|
252 |
+
return_dur: bool = False, # return duration prediction
|
253 |
+
):
|
254 |
+
batch, seq_len = x.shape[0], x.shape[1]
|
255 |
+
|
256 |
+
if time.ndim == 0:
|
257 |
+
time = time.repeat(batch)
|
258 |
+
|
259 |
+
t = self.time_embed(time)
|
260 |
+
|
261 |
+
clip_emb = F.normalize(clip, dim=-1)
|
262 |
+
clip_emb = self.clip_embed_affine_layer(clip)
|
263 |
+
|
264 |
+
spk_emb = F.normalize(spk_emb, dim=-1)
|
265 |
+
spk_emb = self.spk_embed_affine_layer(spk_emb)
|
266 |
+
spk_emb = torch.repeat_interleave(spk_emb, seq_len, dim=1)
|
267 |
+
|
268 |
+
if caption is None:
|
269 |
+
caption = torch.zeros(1, seq_len, 512).to(device=x.device)
|
270 |
+
|
271 |
+
caption_emb = F.normalize(caption, dim=-1)
|
272 |
+
caption_emb = self.caption_embed_affine_layer(caption_emb)
|
273 |
+
|
274 |
+
text_embed = self.text_embed(text, seq_len, drop_text=[False])
|
275 |
+
|
276 |
+
x = self.input_embed(x, text_embed, clip_emb, spk_emb, caption_emb)
|
277 |
+
|
278 |
+
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
279 |
+
|
280 |
+
info = []
|
281 |
+
for i, block in enumerate(self.transformer_blocks1):
|
282 |
+
x = block(x, t, mask=mask, rope=rope) # 'b n 1024'
|
283 |
+
|
284 |
+
info.append(x)
|
285 |
+
|
286 |
+
out_info = []
|
287 |
+
for i, linear in enumerate(self.zero_linear):
|
288 |
+
h = linear(info[i])
|
289 |
+
out_info.append(h)
|
290 |
+
|
291 |
+
if return_dur and self.duration_predictor is not None:
|
292 |
+
dur_loss = self.duration_predictor(x=x, text=clip_emb, lens=lens)
|
293 |
+
|
294 |
+
return out_info, dur_loss
|
295 |
+
|
296 |
+
else:
|
297 |
+
return out_info
|
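Note that the 12 `zero_linear` projections are zero-initialized, so the ControlNet branch is a no-op at the start of training; the DiT forward pass above adds its outputs to the inputs of the first 12 transformer blocks. A minimal, hypothetical wiring sketch (it assumes `dit` is an instance of the DiT-style model defined above and `controlnet` an instance of `ControlNetDiT`, both built with a matching `dim`; the tensor shapes below are illustrative only):

```python
import torch

x = torch.randn(1, 256, 100)             # noised mel frames, 'b n d' (mel_dim=100)
cond = torch.randn(1, 256, 100)          # masked reference mel
text = torch.randint(0, 256, (1, 64))    # tokenized text ids
clip = torch.randn(1, 256, 768)          # per-frame video features (768-d, see clip_embed_affine_layer)
spk = torch.randn(1, 1, 192)             # speaker embedding (192-d), repeated over the sequence inside forward
t = torch.rand(1)                        # flow-matching time step

# ControlNetDiT returns a list of zero-initialized projections of its hidden states,
# one per zero_linear layer ...
ctrl = controlnet(x=x, text=text, clip=clip, spk_emb=spk, time=t)

# ... which the DiT adds to the inputs of its first 12 blocks via `controlnet_embeds`.
pred = dit(
    x=x, cond=cond, text=text, time=t,
    drop_audio_cond=False, drop_text=False,
    controlnet_embeds=ctrl,
)
```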
src/moviedubber/model/modules.py
ADDED
@@ -0,0 +1,467 @@
1 |
+
"""
|
2 |
+
ein notation:
|
3 |
+
b - batch
|
4 |
+
n - sequence
|
5 |
+
nt - text sequence
|
6 |
+
nw - raw wave length
|
7 |
+
d - dimension
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import annotations
|
11 |
+
|
12 |
+
import math
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from librosa.filters import mel as librosa_mel_fn
|
18 |
+
from torch import nn
|
19 |
+
from x_transformers.x_transformers import apply_rotary_pos_emb
|
20 |
+
|
21 |
+
|
22 |
+
# raw wav to mel spec
|
23 |
+
|
24 |
+
|
25 |
+
mel_basis_cache = {}
|
26 |
+
hann_window_cache = {}
|
27 |
+
|
28 |
+
|
29 |
+
def get_bigvgan_mel_spectrogram(
|
30 |
+
waveform,
|
31 |
+
n_fft=1024,
|
32 |
+
n_mel_channels=100,
|
33 |
+
target_sample_rate=24000,
|
34 |
+
hop_length=256,
|
35 |
+
win_length=1024,
|
36 |
+
fmin=0,
|
37 |
+
fmax=None,
|
38 |
+
center=False,
|
39 |
+
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
|
40 |
+
device = waveform.device
|
41 |
+
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
42 |
+
|
43 |
+
if key not in mel_basis_cache:
|
44 |
+
mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
|
45 |
+
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
|
46 |
+
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
47 |
+
|
48 |
+
mel_basis = mel_basis_cache[key]
|
49 |
+
hann_window = hann_window_cache[key]
|
50 |
+
|
51 |
+
padding = (n_fft - hop_length) // 2
|
52 |
+
waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
53 |
+
|
54 |
+
spec = torch.stft(
|
55 |
+
waveform,
|
56 |
+
n_fft,
|
57 |
+
hop_length=hop_length,
|
58 |
+
win_length=win_length,
|
59 |
+
window=hann_window,
|
60 |
+
center=center,
|
61 |
+
pad_mode="reflect",
|
62 |
+
normalized=False,
|
63 |
+
onesided=True,
|
64 |
+
return_complex=True,
|
65 |
+
)
|
66 |
+
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
67 |
+
|
68 |
+
mel_spec = torch.matmul(mel_basis, spec)
|
69 |
+
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
70 |
+
|
71 |
+
return mel_spec
|
72 |
+
|
73 |
+
|
74 |
+
class MelSpec(nn.Module):
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
n_fft=1024,
|
78 |
+
hop_length=256,
|
79 |
+
win_length=1024,
|
80 |
+
n_mel_channels=100,
|
81 |
+
target_sample_rate=24_000,
|
82 |
+
mel_spec_type="bigvgan",
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
self.n_fft = n_fft
|
87 |
+
self.hop_length = hop_length
|
88 |
+
self.win_length = win_length
|
89 |
+
self.n_mel_channels = n_mel_channels
|
90 |
+
self.target_sample_rate = target_sample_rate
|
91 |
+
|
92 |
+
self.extractor = get_bigvgan_mel_spectrogram
|
93 |
+
|
94 |
+
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
95 |
+
|
96 |
+
def forward(self, wav):
|
97 |
+
if self.dummy.device != wav.device:
|
98 |
+
self.to(wav.device)
|
99 |
+
|
100 |
+
mel = self.extractor(
|
101 |
+
waveform=wav,
|
102 |
+
n_fft=self.n_fft,
|
103 |
+
n_mel_channels=self.n_mel_channels,
|
104 |
+
target_sample_rate=self.target_sample_rate,
|
105 |
+
hop_length=self.hop_length,
|
106 |
+
win_length=self.win_length,
|
107 |
+
)
|
108 |
+
|
109 |
+
return mel
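A small usage sketch for `MelSpec` (the random waveform is purely illustrative; with the 24 kHz / hop 256 defaults, one second of audio yields roughly 24000 / 256 ≈ 93 frames):

```python
import torch

mel_spec = MelSpec()            # defaults: 24 kHz, n_fft=1024, hop=256, 100 mel bands
wav = torch.randn(1, 24000)     # one second of (fake) audio in [-1, 1], shape [B, T]
mel = mel_spec(wav)             # log-mel spectrogram, shape [B, 100, T_frames]
```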
|
110 |
+
|
111 |
+
|
112 |
+
# sinusoidal position embedding
|
113 |
+
|
114 |
+
|
115 |
+
class SinusPositionEmbedding(nn.Module):
|
116 |
+
def __init__(self, dim):
|
117 |
+
super().__init__()
|
118 |
+
self.dim = dim
|
119 |
+
|
120 |
+
def forward(self, x, scale=1000):
|
121 |
+
device = x.device
|
122 |
+
half_dim = self.dim // 2
|
123 |
+
emb = math.log(10000) / (half_dim - 1)
|
124 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
125 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
126 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
127 |
+
return emb
|
128 |
+
|
129 |
+
|
130 |
+
# convolutional position embedding
|
131 |
+
|
132 |
+
|
133 |
+
class ConvPositionEmbedding(nn.Module):
|
134 |
+
def __init__(self, dim, kernel_size=31, groups=16):
|
135 |
+
super().__init__()
|
136 |
+
assert kernel_size % 2 != 0
|
137 |
+
self.conv1d = nn.Sequential(
|
138 |
+
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
139 |
+
nn.Mish(),
|
140 |
+
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
141 |
+
nn.Mish(),
|
142 |
+
)
|
143 |
+
|
144 |
+
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
145 |
+
if mask is not None:
|
146 |
+
mask = mask[..., None]
|
147 |
+
x = x.masked_fill(~mask, 0.0)
|
148 |
+
|
149 |
+
x = x.permute(0, 2, 1)
|
150 |
+
x = self.conv1d(x)
|
151 |
+
out = x.permute(0, 2, 1)
|
152 |
+
|
153 |
+
if mask is not None:
|
154 |
+
out = out.masked_fill(~mask, 0.0)
|
155 |
+
|
156 |
+
return out
|
157 |
+
|
158 |
+
|
159 |
+
# rotary positional embedding related
|
160 |
+
|
161 |
+
|
162 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
163 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
164 |
+
# has some connection to NTK literature
|
165 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
166 |
+
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
167 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
168 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
169 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
170 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
171 |
+
freqs_cos = torch.cos(freqs) # real part
|
172 |
+
freqs_sin = torch.sin(freqs) # imaginary part
|
173 |
+
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
174 |
+
|
175 |
+
|
176 |
+
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
177 |
+
# length = length if isinstance(length, int) else length.max()
|
178 |
+
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
179 |
+
pos = (
|
180 |
+
start.unsqueeze(1)
|
181 |
+
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
182 |
+
)
|
183 |
+
# avoid extra long error.
|
184 |
+
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
185 |
+
return pos
|
186 |
+
|
187 |
+
|
188 |
+
# Global Response Normalization layer (Instance Normalization ?)
|
189 |
+
|
190 |
+
|
191 |
+
class GRN(nn.Module):
|
192 |
+
def __init__(self, dim):
|
193 |
+
super().__init__()
|
194 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
195 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
199 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
200 |
+
return self.gamma * (x * Nx) + self.beta + x
|
201 |
+
|
202 |
+
|
203 |
+
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
204 |
+
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
205 |
+
|
206 |
+
|
207 |
+
class ConvNeXtV2Block(nn.Module):
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
dim: int,
|
211 |
+
intermediate_dim: int,
|
212 |
+
dilation: int = 1,
|
213 |
+
):
|
214 |
+
super().__init__()
|
215 |
+
padding = (dilation * (7 - 1)) // 2
|
216 |
+
self.dwconv = nn.Conv1d(
|
217 |
+
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
218 |
+
) # depthwise conv
|
219 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
220 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
221 |
+
self.act = nn.GELU()
|
222 |
+
self.grn = GRN(intermediate_dim)
|
223 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
224 |
+
|
225 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
226 |
+
residual = x
|
227 |
+
x = x.transpose(1, 2) # b n d -> b d n
|
228 |
+
x = self.dwconv(x)
|
229 |
+
x = x.transpose(1, 2) # b d n -> b n d
|
230 |
+
x = self.norm(x)
|
231 |
+
x = self.pwconv1(x)
|
232 |
+
x = self.act(x)
|
233 |
+
x = self.grn(x)
|
234 |
+
x = self.pwconv2(x)
|
235 |
+
return residual + x
|
236 |
+
|
237 |
+
|
238 |
+
# AdaLayerNormZero
|
239 |
+
# return with modulated x for attn input, and params for later mlp modulation
|
240 |
+
|
241 |
+
|
242 |
+
class AdaLayerNormZero(nn.Module):
|
243 |
+
def __init__(self, dim):
|
244 |
+
super().__init__()
|
245 |
+
|
246 |
+
self.silu = nn.SiLU()
|
247 |
+
self.linear = nn.Linear(dim, dim * 6)
|
248 |
+
|
249 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
250 |
+
|
251 |
+
def forward(self, x, emb=None):
|
252 |
+
emb = self.linear(self.silu(emb))
|
253 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
254 |
+
|
255 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
256 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
257 |
+
|
258 |
+
|
259 |
+
# AdaLayerNormZero for final layer
|
260 |
+
# return only with modulated x for attn input, cuz no more mlp modulation
|
261 |
+
|
262 |
+
|
263 |
+
class AdaLayerNormZero_Final(nn.Module):
|
264 |
+
def __init__(self, dim):
|
265 |
+
super().__init__()
|
266 |
+
|
267 |
+
self.silu = nn.SiLU()
|
268 |
+
self.linear = nn.Linear(dim, dim * 2)
|
269 |
+
|
270 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
271 |
+
|
272 |
+
def forward(self, x, emb):
|
273 |
+
emb = self.linear(self.silu(emb))
|
274 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
275 |
+
|
276 |
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
# FeedForward
|
281 |
+
|
282 |
+
|
283 |
+
class FeedForward(nn.Module):
|
284 |
+
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
285 |
+
super().__init__()
|
286 |
+
inner_dim = int(dim * mult)
|
287 |
+
dim_out = dim_out if dim_out is not None else dim
|
288 |
+
|
289 |
+
activation = nn.GELU(approximate=approximate)
|
290 |
+
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
291 |
+
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
292 |
+
|
293 |
+
def forward(self, x):
|
294 |
+
return self.ff(x)
|
295 |
+
|
296 |
+
|
297 |
+
# Attention with possible joint part
|
298 |
+
# modified from diffusers/src/diffusers/models/attention_processor.py
|
299 |
+
|
300 |
+
|
301 |
+
class Attention(nn.Module):
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
processor: AttnProcessor,
|
305 |
+
dim: int,
|
306 |
+
heads: int = 8,
|
307 |
+
dim_head: int = 64,
|
308 |
+
dropout: float = 0.0,
|
309 |
+
context_dim: Optional[int] = None, # if not None -> joint attention
|
310 |
+
context_pre_only=None,
|
311 |
+
):
|
312 |
+
super().__init__()
|
313 |
+
|
314 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
315 |
+
raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
|
316 |
+
|
317 |
+
self.processor = processor
|
318 |
+
|
319 |
+
self.dim = dim
|
320 |
+
self.heads = heads
|
321 |
+
self.inner_dim = dim_head * heads
|
322 |
+
self.dropout = dropout
|
323 |
+
|
324 |
+
self.context_dim = context_dim
|
325 |
+
self.context_pre_only = context_pre_only
|
326 |
+
|
327 |
+
self.to_q = nn.Linear(dim, self.inner_dim)
|
328 |
+
self.to_k = nn.Linear(dim, self.inner_dim)
|
329 |
+
self.to_v = nn.Linear(dim, self.inner_dim)
|
330 |
+
|
331 |
+
if self.context_dim is not None:
|
332 |
+
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
333 |
+
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
334 |
+
if self.context_pre_only is not None:
|
335 |
+
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
336 |
+
|
337 |
+
self.to_out = nn.ModuleList([])
|
338 |
+
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
339 |
+
self.to_out.append(nn.Dropout(dropout))
|
340 |
+
|
341 |
+
if self.context_pre_only is not None and not self.context_pre_only:
|
342 |
+
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
343 |
+
|
344 |
+
def forward(
|
345 |
+
self,
|
346 |
+
x: float["b n d"], # noised input x # noqa: F722
|
347 |
+
c: float["b n d"] = None, # context c # noqa: F722
|
348 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
349 |
+
rope=None, # rotary position embedding for x
|
350 |
+
c_rope=None, # rotary position embedding for c
|
351 |
+
) -> torch.Tensor:
|
352 |
+
if c is not None:
|
353 |
+
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
354 |
+
else:
|
355 |
+
return self.processor(self, x, mask=mask, rope=rope)
|
356 |
+
|
357 |
+
|
358 |
+
# Attention processor
|
359 |
+
|
360 |
+
|
361 |
+
class AttnProcessor:
|
362 |
+
def __init__(self):
|
363 |
+
pass
|
364 |
+
|
365 |
+
def __call__(
|
366 |
+
self,
|
367 |
+
attn: Attention,
|
368 |
+
x: float["b n d"], # noised input x # noqa: F722
|
369 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
370 |
+
rope=None, # rotary position embedding
|
371 |
+
) -> torch.FloatTensor:
|
372 |
+
batch_size = x.shape[0]
|
373 |
+
|
374 |
+
# `sample` projections.
|
375 |
+
query = attn.to_q(x)
|
376 |
+
key = attn.to_k(x)
|
377 |
+
value = attn.to_v(x)
|
378 |
+
|
379 |
+
# apply rotary position embedding
|
380 |
+
if rope is not None:
|
381 |
+
freqs, xpos_scale = rope
|
382 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
383 |
+
|
384 |
+
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
385 |
+
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
386 |
+
|
387 |
+
# attention
|
388 |
+
inner_dim = key.shape[-1]
|
389 |
+
head_dim = inner_dim // attn.heads
|
390 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
391 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
392 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
393 |
+
|
394 |
+
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
395 |
+
if mask is not None:
|
396 |
+
attn_mask = mask
|
397 |
+
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
398 |
+
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
399 |
+
else:
|
400 |
+
attn_mask = None
|
401 |
+
|
402 |
+
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
403 |
+
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
404 |
+
x = x.to(query.dtype)
|
405 |
+
|
406 |
+
# linear proj
|
407 |
+
x = attn.to_out[0](x)
|
408 |
+
# dropout
|
409 |
+
x = attn.to_out[1](x)
|
410 |
+
|
411 |
+
if mask is not None:
|
412 |
+
mask = mask.unsqueeze(-1)
|
413 |
+
x = x.masked_fill(~mask, 0.0)
|
414 |
+
|
415 |
+
return x
|
416 |
+
|
417 |
+
|
418 |
+
# DiT Block
|
419 |
+
|
420 |
+
|
421 |
+
class DiTBlock(nn.Module):
|
422 |
+
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
423 |
+
super().__init__()
|
424 |
+
|
425 |
+
self.attn_norm = AdaLayerNormZero(dim)
|
426 |
+
self.attn = Attention(
|
427 |
+
processor=AttnProcessor(),
|
428 |
+
dim=dim,
|
429 |
+
heads=heads,
|
430 |
+
dim_head=dim_head,
|
431 |
+
dropout=dropout,
|
432 |
+
)
|
433 |
+
|
434 |
+
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
435 |
+
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
436 |
+
|
437 |
+
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
438 |
+
# pre-norm & modulation for attention input
|
439 |
+
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
440 |
+
|
441 |
+
# attention
|
442 |
+
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
443 |
+
|
444 |
+
# process attention output for input x
|
445 |
+
x = x + gate_msa.unsqueeze(1) * attn_output
|
446 |
+
|
447 |
+
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
448 |
+
ff_output = self.ff(norm)
|
449 |
+
x = x + gate_mlp.unsqueeze(1) * ff_output
|
450 |
+
|
451 |
+
return x
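A hedged smoke test for `DiTBlock` (dimensions are illustrative; note that `t` is the output of `TimestepEmbedding`, defined later in this file, i.e. a `[b, dim]` conditioning vector rather than a raw scalar time step):

```python
import torch

block = DiTBlock(dim=1024, heads=16, dim_head=64)
time_embed = TimestepEmbedding(dim=1024)

x = torch.randn(2, 128, 1024)    # hidden states, 'b n d'
t = time_embed(torch.rand(2))    # per-sample conditioning, 'b d'

y = block(x, t)                  # mask and rope are optional; output keeps the input shape
assert y.shape == x.shape
```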
|
452 |
+
|
453 |
+
|
454 |
+
# time step conditioning embedding
|
455 |
+
|
456 |
+
|
457 |
+
class TimestepEmbedding(nn.Module):
|
458 |
+
def __init__(self, dim, freq_embed_dim=256):
|
459 |
+
super().__init__()
|
460 |
+
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
461 |
+
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
462 |
+
|
463 |
+
def forward(self, timestep: float["b"]): # noqa: F821
|
464 |
+
time_hidden = self.time_embed(timestep)
|
465 |
+
time_hidden = time_hidden.to(timestep.dtype)
|
466 |
+
time = self.time_mlp(time_hidden) # b d
|
467 |
+
return time
|
src/moviedubber/model/utils.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
|
4 |
+
import jieba
|
5 |
+
import torch
|
6 |
+
from pypinyin import Style, lazy_pinyin
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
|
9 |
+
|
10 |
+
def exists(v):
|
11 |
+
return v is not None
|
12 |
+
|
13 |
+
|
14 |
+
def default(v, d):
|
15 |
+
return v if exists(v) else d
|
16 |
+
|
17 |
+
|
18 |
+
# tensor helpers
|
19 |
+
|
20 |
+
|
21 |
+
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
|
22 |
+
if not exists(length):
|
23 |
+
length = t.amax()
|
24 |
+
|
25 |
+
seq = torch.arange(length, device=t.device)
|
26 |
+
return seq[None, :] < t[:, None]
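For example, `lens_to_mask` turns per-sample lengths into a boolean padding mask:

```python
import torch

mask = lens_to_mask(torch.tensor([3, 5]))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```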
|
27 |
+
|
28 |
+
|
29 |
+
# simple utf-8 tokenizer, since paper went character based
|
30 |
+
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
|
31 |
+
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
|
32 |
+
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
# char tokenizer, based on custom dataset's extracted .txt file
|
37 |
+
def list_str_to_idx(
|
38 |
+
text: list[str] | list[list[str]],
|
39 |
+
vocab_char_map: dict[str, int], # {char: idx}
|
40 |
+
padding_value=-1,
|
41 |
+
) -> int["b nt"]: # noqa: F722
|
42 |
+
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
43 |
+
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
44 |
+
return text
|
45 |
+
|
46 |
+
|
47 |
+
# Get tokenizer
|
48 |
+
|
49 |
+
|
50 |
+
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
51 |
+
"""
|
52 |
+
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
53 |
+
- "char" for char-wise tokenizer, need .txt vocab_file
|
54 |
+
- "byte" for utf-8 tokenizer
|
55 |
+
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
56 |
+
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
57 |
+
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
58 |
+
- if use "byte", set to 256 (unicode byte range)
|
59 |
+
"""
|
60 |
+
if tokenizer in ["pinyin", "char"]:
|
61 |
+
# tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
|
62 |
+
tokenizer_path = "/ailab-train/speech/zhengjunjie/huggingface/models/F5-TTS/F5TTS_Base/vocab.txt"
|
63 |
+
print(f"Loading {tokenizer} tokenizer from {tokenizer_path}")
|
64 |
+
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
65 |
+
vocab_char_map = {}
|
66 |
+
for i, char in enumerate(f):
|
67 |
+
vocab_char_map[char[:-1]] = i
|
68 |
+
vocab_size = len(vocab_char_map)
|
69 |
+
assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
|
70 |
+
|
71 |
+
elif tokenizer == "byte":
|
72 |
+
vocab_char_map = None
|
73 |
+
vocab_size = 256
|
74 |
+
|
75 |
+
elif tokenizer == "custom":
|
76 |
+
with open(dataset_name, "r", encoding="utf-8") as f:
|
77 |
+
vocab_char_map = {}
|
78 |
+
for i, char in enumerate(f):
|
79 |
+
vocab_char_map[char[:-1]] = i
|
80 |
+
vocab_size = len(vocab_char_map)
|
81 |
+
|
82 |
+
return vocab_char_map, vocab_size
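Note that the `pinyin`/`char` branch hard-codes an absolute `vocab.txt` path, so outside that environment the `custom` tokenizer with an explicit path is the portable option (the path below is a placeholder):

```python
vocab_char_map, vocab_size = get_tokenizer("/path/to/your/vocab.txt", tokenizer="custom")
```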
|
83 |
+
|
84 |
+
|
85 |
+
# convert char to pinyin
|
86 |
+
|
87 |
+
jieba.initialize()
|
88 |
+
print("Word segmentation module jieba initialized.\n")
|
89 |
+
|
90 |
+
|
91 |
+
def convert_char_to_pinyin(text_list, polyphone=True):
|
92 |
+
final_text_list = []
|
93 |
+
custom_trans = str.maketrans(
|
94 |
+
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
95 |
+
) # add custom trans here, to address oov
|
96 |
+
|
97 |
+
def is_chinese(c):
|
98 |
+
return (
|
99 |
+
"\u3100" <= c <= "\u9fff" # common chinese characters
|
100 |
+
)
|
101 |
+
|
102 |
+
for text in text_list:
|
103 |
+
char_list = []
|
104 |
+
text = text.translate(custom_trans)
|
105 |
+
for seg in jieba.cut(text):
|
106 |
+
seg_byte_len = len(bytes(seg, "UTF-8"))
|
107 |
+
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
108 |
+
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
109 |
+
char_list.append(" ")
|
110 |
+
char_list.extend(seg)
|
111 |
+
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
|
112 |
+
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
113 |
+
for i, c in enumerate(seg):
|
114 |
+
if is_chinese(c):
|
115 |
+
char_list.append(" ")
|
116 |
+
char_list.append(seg_[i])
|
117 |
+
else: # if mixed characters, alphabets and symbols
|
118 |
+
for c in seg:
|
119 |
+
if ord(c) < 256:
|
120 |
+
char_list.extend(c)
|
121 |
+
elif is_chinese(c):
|
122 |
+
char_list.append(" ")
|
123 |
+
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
124 |
+
else:
|
125 |
+
char_list.append(c)
|
126 |
+
final_text_list.append(char_list)
|
127 |
+
|
128 |
+
return final_text_list
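A quick illustration of the output format (the exact segmentation and pinyin depend on the installed jieba/pypinyin versions, so treat the values as indicative):

```python
print(convert_char_to_pinyin(["Hello 世界"]))
# [['H', 'e', 'l', 'l', 'o', ' ', ' ', 'shi4', ' ', 'jie4']]  # latin chars kept, hanzi -> tone3 pinyin
```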
|
src/third_party/BigVGAN/.gitignore
ADDED
@@ -0,0 +1,146 @@
1 |
+
# BigVGAN
|
2 |
+
alias_free_activation/cuda/build/
|
3 |
+
exp/
|
4 |
+
tmp/
|
5 |
+
|
6 |
+
# Symlinks for bundled LibriTTS filelists
|
7 |
+
filelists/LibriTTS/train-clean-100
|
8 |
+
filelists/LibriTTS/train-clean-360
|
9 |
+
filelists/LibriTTS/train-other-500
|
10 |
+
filelists/LibriTTS/dev-clean
|
11 |
+
filelists/LibriTTS/dev-other
|
12 |
+
filelists/LibriTTS/test-clean
|
13 |
+
filelists/LibriTTS/test-other
|
14 |
+
|
15 |
+
# VSCode configs
|
16 |
+
.vscode/
|
17 |
+
|
18 |
+
# Byte-compiled / optimized / DLL files
|
19 |
+
__pycache__/
|
20 |
+
*.py[cod]
|
21 |
+
*$py.class
|
22 |
+
|
23 |
+
# C extensions
|
24 |
+
*.so
|
25 |
+
|
26 |
+
# Distribution / packaging
|
27 |
+
.Python
|
28 |
+
build/
|
29 |
+
develop-eggs/
|
30 |
+
dist/
|
31 |
+
downloads/
|
32 |
+
eggs/
|
33 |
+
.eggs/
|
34 |
+
lib/
|
35 |
+
lib64/
|
36 |
+
parts/
|
37 |
+
sdist/
|
38 |
+
var/
|
39 |
+
wheels/
|
40 |
+
share/python-wheels/
|
41 |
+
*.egg-info/
|
42 |
+
.installed.cfg
|
43 |
+
*.egg
|
44 |
+
MANIFEST
|
45 |
+
|
46 |
+
# PyInstaller
|
47 |
+
# Usually these files are written by a python script from a template
|
48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
49 |
+
*.manifest
|
50 |
+
*.spec
|
51 |
+
|
52 |
+
# Installer logs
|
53 |
+
pip-log.txt
|
54 |
+
pip-delete-this-directory.txt
|
55 |
+
|
56 |
+
# Unit test / coverage reports
|
57 |
+
htmlcov/
|
58 |
+
.tox/
|
59 |
+
.nox/
|
60 |
+
.coverage
|
61 |
+
.coverage.*
|
62 |
+
.cache
|
63 |
+
nosetests.xml
|
64 |
+
coverage.xml
|
65 |
+
*.cover
|
66 |
+
*.py,cover
|
67 |
+
.hypothesis/
|
68 |
+
.pytest_cache/
|
69 |
+
cover/
|
70 |
+
|
71 |
+
# Translations
|
72 |
+
*.mo
|
73 |
+
*.pot
|
74 |
+
|
75 |
+
# Django stuff:
|
76 |
+
*.log
|
77 |
+
local_settings.py
|
78 |
+
db.sqlite3
|
79 |
+
db.sqlite3-journal
|
80 |
+
|
81 |
+
# Flask stuff:
|
82 |
+
instance/
|
83 |
+
.webassets-cache
|
84 |
+
|
85 |
+
# Scrapy stuff:
|
86 |
+
.scrapy
|
87 |
+
|
88 |
+
# Sphinx documentation
|
89 |
+
docs/_build/
|
90 |
+
|
91 |
+
# PyBuilder
|
92 |
+
.pybuilder/
|
93 |
+
target/
|
94 |
+
|
95 |
+
# Jupyter Notebook
|
96 |
+
.ipynb_checkpoints
|
97 |
+
|
98 |
+
# IPython
|
99 |
+
profile_default/
|
100 |
+
ipython_config.py
|
101 |
+
|
102 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
103 |
+
__pypackages__/
|
104 |
+
|
105 |
+
# Celery stuff
|
106 |
+
celerybeat-schedule
|
107 |
+
celerybeat.pid
|
108 |
+
|
109 |
+
# SageMath parsed files
|
110 |
+
*.sage.py
|
111 |
+
|
112 |
+
# Environments
|
113 |
+
.env
|
114 |
+
.venv
|
115 |
+
env/
|
116 |
+
venv/
|
117 |
+
ENV/
|
118 |
+
env.bak/
|
119 |
+
venv.bak/
|
120 |
+
|
121 |
+
# Spyder project settings
|
122 |
+
.spyderproject
|
123 |
+
.spyproject
|
124 |
+
|
125 |
+
# Rope project settings
|
126 |
+
.ropeproject
|
127 |
+
|
128 |
+
# mkdocs documentation
|
129 |
+
/site
|
130 |
+
|
131 |
+
# mypy
|
132 |
+
.mypy_cache/
|
133 |
+
.dmypy.json
|
134 |
+
dmypy.json
|
135 |
+
|
136 |
+
# Pyre type checker
|
137 |
+
.pyre/
|
138 |
+
|
139 |
+
# pytype static type analyzer
|
140 |
+
.pytype/
|
141 |
+
|
142 |
+
# Cython debug symbols
|
143 |
+
cython_debug/
|
144 |
+
|
145 |
+
# PyCharm
|
146 |
+
.idea/
|
src/third_party/BigVGAN/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 NVIDIA CORPORATION.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/third_party/BigVGAN/README.md
ADDED
@@ -0,0 +1,266 @@
1 |
+
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
2 |
+
|
3 |
+
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
|
4 |
+
|
5 |
+
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
|
6 |
+
|
7 |
+
[](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
|
8 |
+
|
9 |
+
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
10 |
+
|
11 |
+
## News
|
12 |
+
- **Sep 2024 (v2.4):**
|
13 |
+
- We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
|
14 |
+
|
15 |
+
- **Jul 2024 (v2.3):**
|
16 |
+
- General refactor and code improvements for improved readability.
|
17 |
+
- Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
|
18 |
+
|
19 |
+
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
|
20 |
+
|
21 |
+
- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
|
22 |
+
|
23 |
+
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
|
24 |
+
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
|
25 |
+
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
|
26 |
+
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
27 |
+
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
|
28 |
+
|
29 |
+
## Installation
|
30 |
+
|
31 |
+
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
|
32 |
+
|
33 |
+
```shell
|
34 |
+
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
35 |
+
conda activate bigvgan
|
36 |
+
```
|
37 |
+
|
38 |
+
Clone the repository and install dependencies:
|
39 |
+
|
40 |
+
```shell
|
41 |
+
git clone https://github.com/NVIDIA/BigVGAN
|
42 |
+
cd BigVGAN
|
43 |
+
pip install -r requirements.txt
|
44 |
+
```
|
45 |
+
|
46 |
+
## Inference Quickstart using 🤗 Hugging Face Hub
|
47 |
+
|
48 |
+
The example below shows how to use BigVGAN: load the pretrained generator from the Hugging Face Hub, compute a mel spectrogram from an input waveform, and synthesize a waveform from that mel spectrogram.
|
49 |
+
|
50 |
+
```python
|
51 |
+
device = 'cuda'
|
52 |
+
|
53 |
+
import torch
|
54 |
+
import bigvgan
|
55 |
+
import librosa
|
56 |
+
from meldataset import get_mel_spectrogram
|
57 |
+
|
58 |
+
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
|
59 |
+
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
|
60 |
+
|
61 |
+
# remove weight norm in the model and set to eval mode
|
62 |
+
model.remove_weight_norm()
|
63 |
+
model = model.eval().to(device)
|
64 |
+
|
65 |
+
# load wav file and compute mel spectrogram
|
66 |
+
wav_path = '/path/to/your/audio.wav'
|
67 |
+
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
|
68 |
+
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
|
69 |
+
|
70 |
+
# compute mel spectrogram from the ground truth audio
|
71 |
+
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
|
72 |
+
|
73 |
+
# generate waveform from mel
|
74 |
+
with torch.inference_mode():
|
75 |
+
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
|
76 |
+
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
|
77 |
+
|
78 |
+
# you can convert the generated waveform to 16 bit linear PCM
|
79 |
+
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
|
80 |
+
```
|
81 |
+
|
82 |
+
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
|
83 |
+
|
84 |
+
You can run a local gradio demo with the following commands:
|
85 |
+
|
86 |
+
```python
|
87 |
+
pip install -r demo/requirements.txt
|
88 |
+
python demo/app.py
|
89 |
+
```
|
90 |
+
|
91 |
+
## Training
|
92 |
+
|
93 |
+
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
|
94 |
+
|
95 |
+
```shell
|
96 |
+
cd filelists/LibriTTS && \
|
97 |
+
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
|
98 |
+
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
|
99 |
+
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
|
100 |
+
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
|
101 |
+
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
|
102 |
+
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
|
103 |
+
ln -s /path/to/your/LibriTTS/test-other test-other && \
|
104 |
+
cd ../..
|
105 |
+
```
|
106 |
+
|
107 |
+
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
|
108 |
+
|
109 |
+
```shell
|
110 |
+
python train.py \
|
111 |
+
--config configs/bigvgan_v2_24khz_100band_256x.json \
|
112 |
+
--input_wavs_dir filelists/LibriTTS \
|
113 |
+
--input_training_file filelists/LibriTTS/train-full.txt \
|
114 |
+
--input_validation_file filelists/LibriTTS/val-full.txt \
|
115 |
+
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
|
116 |
+
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
|
117 |
+
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
|
118 |
+
```
|
119 |
+
|
120 |
+
## Synthesis
|
121 |
+
|
122 |
+
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
|
123 |
+
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
|
124 |
+
|
125 |
+
```shell
|
126 |
+
python inference.py \
|
127 |
+
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
128 |
+
--input_wavs_dir /path/to/your/input_wav \
|
129 |
+
--output_dir /path/to/your/output_wav
|
130 |
+
```
|
131 |
+
|
132 |
+
`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
|
133 |
+
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
|
134 |
+
|
135 |
+
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
|
136 |
+
|
137 |
+
```shell
|
138 |
+
python inference_e2e.py \
|
139 |
+
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
140 |
+
--input_mels_dir /path/to/your/input_mel \
|
141 |
+
--output_dir /path/to/your/output_wav
|
142 |
+
```
|
143 |
+
|
144 |
+
## Using Custom CUDA Kernel for Synthesis
|
145 |
+
|
146 |
+
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
|
147 |
+
|
148 |
+
```python
|
149 |
+
generator = BigVGAN(h, use_cuda_kernel=True)
|
150 |
+
```
|
151 |
+
|
152 |
+
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
|
153 |
+
|
154 |
+
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
|
155 |
+
|
156 |
+
Please make sure that both `nvcc` and `ninja` are installed on your system, and that the installed `nvcc` matches the CUDA version your PyTorch build was compiled against.
|
157 |
+
|
158 |
+
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
|
159 |
+
|
160 |
+
```shell
|
161 |
+
python tests/test_cuda_vs_torch_model.py \
|
162 |
+
--checkpoint_file /path/to/your/bigvgan_generator.pt
|
163 |
+
```
|
164 |
+
|
165 |
+
```shell
|
166 |
+
loading plain Pytorch BigVGAN
|
167 |
+
...
|
168 |
+
loading CUDA kernel BigVGAN with auto-build
|
169 |
+
Detected CUDA files, patching ldflags
|
170 |
+
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
|
171 |
+
Building extension module anti_alias_activation_cuda...
|
172 |
+
...
|
173 |
+
Loading extension module anti_alias_activation_cuda...
|
174 |
+
...
|
175 |
+
Loading '/path/to/your/bigvgan_generator.pt'
|
176 |
+
...
|
177 |
+
[Success] test CUDA fused vs. plain torch BigVGAN inference
|
178 |
+
> mean_difference=0.0007238413265440613
|
179 |
+
...
|
180 |
+
```
|
181 |
+
|
182 |
+
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
|
183 |
+
|
184 |
+
## Pretrained Models
|
185 |
+
|
186 |
+
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
|
187 |
+
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
|
188 |
+
|
189 |
+
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|
190 |
+
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
|
191 |
+
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
|
192 |
+
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
|
193 |
+
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
194 |
+
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
|
195 |
+
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
196 |
+
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
|
197 |
+
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
|
198 |
+
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
199 |
+
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
200 |
+
|
201 |
+
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
|
202 |
+
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
|
203 |
+
Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
|
204 |
+
|
205 |
+
You can fine-tune the models by:
|
206 |
+
|
207 |
+
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
|
208 |
+
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
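One possible fine-tuning command, mirroring the training example above (all paths are placeholders; placing the downloaded `bigvgan_generator.pt` / `bigvgan_discriminator_optimizer.pt` inside `--checkpoint_path` so that training resumes from them is an assumption about how checkpoints are discovered):

```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir /path/to/your/finetune_wavs \
--input_training_file /path/to/your/train_filelist.txt \
--input_validation_file /path/to/your/val_filelist.txt \
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x_finetune
```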
|
209 |
+
|
210 |
+
## Training Details of BigVGAN-v2
|
211 |
+
|
212 |
+
Compared to the original BigVGAN, the pretrained BigVGAN-v2 checkpoints were trained with `batch_size=32`, a longer `segment_size=65536`, and 8 A100 GPUs.
|
213 |
+
|
214 |
+
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
|
215 |
+
|
216 |
+
When training BigVGAN-v2 from scratch with a small batch size, you may encounter the early divergence problem mentioned in the paper. In that case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. the first 20000 steps) and then raising it back to the default `500`.
|
217 |
+
|
218 |
+
## Evaluation Results of BigVGAN-v2
|
219 |
+
|
220 |
+
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
|
221 |
+
|
222 |
+
| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|
223 |
+
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
|
224 |
+
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
|
225 |
+
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
|
226 |
+
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
|
227 |
+
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
|
228 |
+
|
229 |
+
## Speed Benchmark
|
230 |
+
|
231 |
+
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
|
232 |
+
|
233 |
+
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|
234 |
+
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
|
235 |
+
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
|
236 |
+
| | | True | 3916.5 | 163.2x | 1.3 |
|
237 |
+
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
|
238 |
+
| | | True | 5330.1 | 222.1x | 1.7 |
|
239 |
+
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
|
240 |
+
| | | True | 5761.7 | 240.1x | 4.4 |
|
241 |
+
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
|
242 |
+
| | | True | 1598.1 | 66.6x | 1.3 |
|
243 |
+
| | 2048 | False | 929.9 | 38.7x | 1.7 |
|
244 |
+
| | | True | 1971.3 | 82.1x | 1.6 |
|
245 |
+
| | 16384 | False | 943.4 | 39.3x | 5.0 |
|
246 |
+
| | | True | 2026.5 | 84.4x | 3.9 |
|
247 |
+
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
|
248 |
+
| | | True | 811.3 | 33.8x | 1.3 |
|
249 |
+
| | 2048 | False | 576.5 | 24.0x | 1.7 |
|
250 |
+
| | | True | 1023.0 | 42.6x | 1.5 |
|
251 |
+
| | 16384 | False | 589.4 | 24.6x | 5.0 |
|
252 |
+
| | | True | 1068.1 | 44.5x | 3.2 |
|
253 |
+
|
254 |
+
## Acknowledgements
|
255 |
+
|
256 |
+
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
|
257 |
+
|
258 |
+
## References
|
259 |
+
|
260 |
+
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
|
261 |
+
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
|
262 |
+
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
|
263 |
+
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
|
264 |
+
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
|
265 |
+
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
|
266 |
+
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
|
src/third_party/BigVGAN/activations.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, sin, pow
|
6 |
+
from torch.nn import Parameter
|
7 |
+
|
8 |
+
|
9 |
+
class Snake(nn.Module):
|
10 |
+
"""
|
11 |
+
Implementation of a sine-based periodic activation function
|
12 |
+
Shape:
|
13 |
+
- Input: (B, C, T)
|
14 |
+
- Output: (B, C, T), same shape as the input
|
15 |
+
Parameters:
|
16 |
+
- alpha - trainable parameter
|
17 |
+
References:
|
18 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
19 |
+
https://arxiv.org/abs/2006.08195
|
20 |
+
Examples:
|
21 |
+
>>> a1 = snake(256)
|
22 |
+
>>> x = torch.randn(256)
|
23 |
+
>>> x = a1(x)
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Initialization.
|
31 |
+
INPUT:
|
32 |
+
- in_features: shape of the input
|
33 |
+
- alpha: trainable parameter
|
34 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
35 |
+
alpha will be trained along with the rest of your model.
|
36 |
+
"""
|
37 |
+
super(Snake, self).__init__()
|
38 |
+
self.in_features = in_features
|
39 |
+
|
40 |
+
# Initialize alpha
|
41 |
+
self.alpha_logscale = alpha_logscale
|
42 |
+
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
43 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
44 |
+
else: # Linear scale alphas initialized to ones
|
45 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
46 |
+
|
47 |
+
self.alpha.requires_grad = alpha_trainable
|
48 |
+
|
49 |
+
self.no_div_by_zero = 0.000000001
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
"""
|
53 |
+
Forward pass of the function.
|
54 |
+
Applies the function to the input elementwise.
|
55 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
56 |
+
"""
|
57 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
58 |
+
if self.alpha_logscale:
|
59 |
+
alpha = torch.exp(alpha)
|
60 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
61 |
+
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class SnakeBeta(nn.Module):
|
66 |
+
"""
|
67 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
68 |
+
Shape:
|
69 |
+
- Input: (B, C, T)
|
70 |
+
- Output: (B, C, T), same shape as the input
|
71 |
+
Parameters:
|
72 |
+
- alpha - trainable parameter that controls frequency
|
73 |
+
- beta - trainable parameter that controls magnitude
|
74 |
+
References:
|
75 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
76 |
+
https://arxiv.org/abs/2006.08195
|
77 |
+
Examples:
|
78 |
+
>>> a1 = snakebeta(256)
|
79 |
+
>>> x = torch.randn(256)
|
80 |
+
>>> x = a1(x)
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
85 |
+
):
|
86 |
+
"""
|
87 |
+
Initialization.
|
88 |
+
INPUT:
|
89 |
+
- in_features: shape of the input
|
90 |
+
- alpha - trainable parameter that controls frequency
|
91 |
+
- beta - trainable parameter that controls magnitude
|
92 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
93 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
94 |
+
alpha will be trained along with the rest of your model.
|
95 |
+
"""
|
96 |
+
super(SnakeBeta, self).__init__()
|
97 |
+
self.in_features = in_features
|
98 |
+
|
99 |
+
# Initialize alpha
|
100 |
+
self.alpha_logscale = alpha_logscale
|
101 |
+
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
102 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
103 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
104 |
+
else: # Linear scale alphas initialized to ones
|
105 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
106 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
107 |
+
|
108 |
+
self.alpha.requires_grad = alpha_trainable
|
109 |
+
self.beta.requires_grad = alpha_trainable
|
110 |
+
|
111 |
+
self.no_div_by_zero = 0.000000001
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""
|
115 |
+
Forward pass of the function.
|
116 |
+
Applies the function to the input elementwise.
|
117 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
118 |
+
"""
|
119 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
120 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
121 |
+
if self.alpha_logscale:
|
122 |
+
alpha = torch.exp(alpha)
|
123 |
+
beta = torch.exp(beta)
|
124 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
125 |
+
|
126 |
+
return x
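A minimal shape sanity check for the activations (values are illustrative):

```python
import torch

act = SnakeBeta(in_features=64, alpha_logscale=True)   # per-channel alpha/beta stored in log scale
x = torch.randn(8, 64, 1024)                            # (B, C, T)
assert act(x).shape == x.shape
```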
|
src/third_party/BigVGAN/alias_free_activation/cuda/__init__.py
ADDED
File without changes
|
src/third_party/BigVGAN/alias_free_activation/cuda/activation1d.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d

# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        raise NotImplementedError
        return output_grads, None, None


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # Whether to use fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
            else:
                beta = (
                    self.act.beta.data
                )  # Snakebeta uses different params for alpha and beta
            alpha = self.act.alpha.data
            if (
                not self.act.alpha_logscale
            ):  # Exp baked into cuda kernel, cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)

            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x
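A hedged sketch of how this fused wrapper is meant to be driven (device, channel count, and sequence length are assumptions; importing the CUDA Activation1d compiles the kernel via load.py below):

# Sketch: anti-aliased SnakeBeta activation, fused CUDA path vs. pure-PyTorch fallback.
import torch
import activations
from alias_free_activation.cuda.activation1d import Activation1d

act = Activation1d(
    activation=activations.SnakeBeta(256, alpha_logscale=True),
    fused=True,  # set False to use the upsample -> activation -> downsample torch path
).cuda()
x = torch.randn(1, 256, 4096, device="cuda")  # (B, C, T), illustrative sizes
y = act(x)  # same shape as x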
src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp
ADDED
@@ -0,0 +1,23 @@
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/extension.h>

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
@@ -0,0 +1,246 @@
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include <cuda.h>
|
19 |
+
#include <cuda_runtime.h>
|
20 |
+
#include <cuda_fp16.h>
|
21 |
+
#include <cuda_profiler_api.h>
|
22 |
+
#include <ATen/cuda/CUDAContext.h>
|
23 |
+
#include <torch/extension.h>
|
24 |
+
#include "type_shim.h"
|
25 |
+
#include <assert.h>
|
26 |
+
#include <cfloat>
|
27 |
+
#include <limits>
|
28 |
+
#include <stdint.h>
|
29 |
+
#include <c10/macros/Macros.h>
|
30 |
+
|
31 |
+
namespace
|
32 |
+
{
|
33 |
+
// Hard-coded hyperparameters
|
34 |
+
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
35 |
+
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
36 |
+
constexpr int BUFFER_SIZE = 32;
|
37 |
+
constexpr int FILTER_SIZE = 12;
|
38 |
+
constexpr int HALF_FILTER_SIZE = 6;
|
39 |
+
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
40 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
41 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
42 |
+
|
43 |
+
template <typename input_t, typename output_t, typename acc_t>
|
44 |
+
__global__ void anti_alias_activation_forward(
|
45 |
+
output_t *dst,
|
46 |
+
const input_t *src,
|
47 |
+
const input_t *up_ftr,
|
48 |
+
const input_t *down_ftr,
|
49 |
+
const input_t *alpha,
|
50 |
+
const input_t *beta,
|
51 |
+
int batch_size,
|
52 |
+
int channels,
|
53 |
+
int seq_len)
|
54 |
+
{
|
55 |
+
// Up and downsample filters
|
56 |
+
input_t up_filter[FILTER_SIZE];
|
57 |
+
input_t down_filter[FILTER_SIZE];
|
58 |
+
|
59 |
+
// Load data from global memory including extra indices reserved for replication paddings
|
60 |
+
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
61 |
+
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
62 |
+
|
63 |
+
// Output stores downsampled output before writing to dst
|
64 |
+
output_t output[BUFFER_SIZE];
|
65 |
+
|
66 |
+
// blockDim/threadIdx = (128, 1, 1)
|
67 |
+
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
68 |
+
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
69 |
+
int local_offset = threadIdx.x * BUFFER_SIZE;
|
70 |
+
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
71 |
+
|
72 |
+
// intermediate have double the seq_len
|
73 |
+
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
74 |
+
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
75 |
+
|
76 |
+
// Get values needed for replication padding before moving pointer
|
77 |
+
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
78 |
+
input_t seq_left_most_value = right_most_pntr[0];
|
79 |
+
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
80 |
+
|
81 |
+
// Move src and dst pointers
|
82 |
+
src += block_offset + local_offset;
|
83 |
+
dst += block_offset + local_offset;
|
84 |
+
|
85 |
+
// Alpha and beta values for snake activatons. Applies exp by default
|
86 |
+
alpha = alpha + blockIdx.y;
|
87 |
+
input_t alpha_val = expf(alpha[0]);
|
88 |
+
beta = beta + blockIdx.y;
|
89 |
+
input_t beta_val = expf(beta[0]);
|
90 |
+
|
91 |
+
#pragma unroll
|
92 |
+
for (int it = 0; it < FILTER_SIZE; it += 1)
|
93 |
+
{
|
94 |
+
up_filter[it] = up_ftr[it];
|
95 |
+
down_filter[it] = down_ftr[it];
|
96 |
+
}
|
97 |
+
|
98 |
+
// Apply replication padding for upsampling, matching torch impl
|
99 |
+
#pragma unroll
|
100 |
+
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
101 |
+
{
|
102 |
+
int element_index = seq_offset + it; // index for element
|
103 |
+
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
104 |
+
{
|
105 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
106 |
+
}
|
107 |
+
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
108 |
+
{
|
109 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
110 |
+
}
|
111 |
+
if ((element_index >= 0) && (element_index < seq_len))
|
112 |
+
{
|
113 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
114 |
+
}
|
115 |
+
}
|
116 |
+
|
117 |
+
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
118 |
+
#pragma unroll
|
119 |
+
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
120 |
+
{
|
121 |
+
input_t acc = 0.0;
|
122 |
+
int element_index = intermediate_seq_offset + it; // index for intermediate
|
123 |
+
#pragma unroll
|
124 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
125 |
+
{
|
126 |
+
if ((element_index + f_idx) >= 0)
|
127 |
+
{
|
128 |
+
acc += up_filter[f_idx] * elements[it + f_idx];
|
129 |
+
}
|
130 |
+
}
|
131 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
132 |
+
}
|
133 |
+
|
134 |
+
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
135 |
+
double no_div_by_zero = 0.000000001;
|
136 |
+
#pragma unroll
|
137 |
+
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
138 |
+
{
|
139 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
140 |
+
}
|
141 |
+
|
142 |
+
// Apply replication padding before downsampling conv from intermediates
|
143 |
+
#pragma unroll
|
144 |
+
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
145 |
+
{
|
146 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
147 |
+
}
|
148 |
+
#pragma unroll
|
149 |
+
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
150 |
+
{
|
151 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
152 |
+
}
|
153 |
+
|
154 |
+
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
155 |
+
#pragma unroll
|
156 |
+
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
157 |
+
{
|
158 |
+
input_t acc = 0.0;
|
159 |
+
#pragma unroll
|
160 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
161 |
+
{
|
162 |
+
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
163 |
+
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
164 |
+
}
|
165 |
+
output[it] = acc;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Write output to dst
|
169 |
+
#pragma unroll
|
170 |
+
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
171 |
+
{
|
172 |
+
int element_index = seq_offset + it;
|
173 |
+
if (element_index < seq_len)
|
174 |
+
{
|
175 |
+
dst[it] = output[it];
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
}
|
180 |
+
|
181 |
+
template <typename input_t, typename output_t, typename acc_t>
|
182 |
+
void dispatch_anti_alias_activation_forward(
|
183 |
+
output_t *dst,
|
184 |
+
const input_t *src,
|
185 |
+
const input_t *up_ftr,
|
186 |
+
const input_t *down_ftr,
|
187 |
+
const input_t *alpha,
|
188 |
+
const input_t *beta,
|
189 |
+
int batch_size,
|
190 |
+
int channels,
|
191 |
+
int seq_len)
|
192 |
+
{
|
193 |
+
if (seq_len == 0)
|
194 |
+
{
|
195 |
+
return;
|
196 |
+
}
|
197 |
+
else
|
198 |
+
{
|
199 |
+
// Use 128 threads per block to maximimize gpu utilization
|
200 |
+
constexpr int threads_per_block = 128;
|
201 |
+
constexpr int seq_len_per_block = 4096;
|
202 |
+
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
203 |
+
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
204 |
+
dim3 threads(threads_per_block, 1, 1);
|
205 |
+
|
206 |
+
anti_alias_activation_forward<input_t, output_t, acc_t>
|
207 |
+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
213 |
+
{
|
214 |
+
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
215 |
+
const int batches = input.size(0);
|
216 |
+
const int channels = input.size(1);
|
217 |
+
const int seq_len = input.size(2);
|
218 |
+
|
219 |
+
// Output
|
220 |
+
auto act_options = input.options().requires_grad(false);
|
221 |
+
|
222 |
+
torch::Tensor anti_alias_activation_results =
|
223 |
+
torch::empty({batches, channels, seq_len}, act_options);
|
224 |
+
|
225 |
+
void *input_ptr = static_cast<void *>(input.data_ptr());
|
226 |
+
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
227 |
+
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
228 |
+
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
229 |
+
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
230 |
+
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
231 |
+
|
232 |
+
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
233 |
+
input.scalar_type(),
|
234 |
+
"dispatch anti alias activation_forward",
|
235 |
+
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
236 |
+
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
237 |
+
reinterpret_cast<const scalar_t *>(input_ptr),
|
238 |
+
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
239 |
+
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
240 |
+
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
241 |
+
reinterpret_cast<const scalar_t *>(beta_ptr),
|
242 |
+
batches,
|
243 |
+
channels,
|
244 |
+
seq_len););
|
245 |
+
return anti_alias_activation_results;
|
246 |
+
}
|
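In other words, each launch of this kernel uses 128 threads per block and BUFFER_SIZE = 32 output samples per thread, so one block covers 128 * 32 = 4096 samples along the time axis; that product is exactly the seq_len_per_block constant used to compute blocks_per_seq_len, and the grid is laid out as (seq_blocks, channels, batches).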
src/third_party/BigVGAN/alias_free_activation/cuda/compat.h
ADDED
@@ -0,0 +1,29 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
src/third_party/BigVGAN/alias_free_activation/cuda/load.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

"""
Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]
    anti_alias_activation_cuda = _cpp_extention_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )

    return anti_alias_activation_cuda


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
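A small sketch of exercising this build helper directly (an nvcc/ninja toolchain matching the PyTorch build is assumed; the tensors below are placeholders only to show the bound entry point's signature):

# Sketch: JIT-compile the fused kernel and call its bound forward() directly.
import torch
from alias_free_activation.cuda import load

ext = load.load()  # compiles anti_alias_activation_cuda into ./build on first use
x = torch.randn(1, 4, 1024, device="cuda")  # (batches, channels, seq_len), illustrative
up = torch.randn(12, device="cuda")         # placeholder 12-tap upsampling filter
down = torch.randn(12, device="cuda")       # placeholder 12-tap downsampling filter
alpha = torch.zeros(4, device="cuda")       # log-scale alpha/beta, one per channel
beta = torch.zeros(4, device="cuda")
y = ext.forward(x, up, down, alpha, beta)   # output has the same (B, C, T) shape as x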
src/third_party/BigVGAN/alias_free_activation/cuda/type_shim.h
ADDED
@@ -0,0 +1,92 @@
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include "compat.h"
|
19 |
+
|
20 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
21 |
+
switch (TYPE) \
|
22 |
+
{ \
|
23 |
+
case at::ScalarType::Float: \
|
24 |
+
{ \
|
25 |
+
using scalar_t = float; \
|
26 |
+
__VA_ARGS__; \
|
27 |
+
break; \
|
28 |
+
} \
|
29 |
+
case at::ScalarType::Half: \
|
30 |
+
{ \
|
31 |
+
using scalar_t = at::Half; \
|
32 |
+
__VA_ARGS__; \
|
33 |
+
break; \
|
34 |
+
} \
|
35 |
+
case at::ScalarType::BFloat16: \
|
36 |
+
{ \
|
37 |
+
using scalar_t = at::BFloat16; \
|
38 |
+
__VA_ARGS__; \
|
39 |
+
break; \
|
40 |
+
} \
|
41 |
+
default: \
|
42 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
43 |
+
}
|
44 |
+
|
45 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
46 |
+
switch (TYPEIN) \
|
47 |
+
{ \
|
48 |
+
case at::ScalarType::Float: \
|
49 |
+
{ \
|
50 |
+
using scalar_t_in = float; \
|
51 |
+
switch (TYPEOUT) \
|
52 |
+
{ \
|
53 |
+
case at::ScalarType::Float: \
|
54 |
+
{ \
|
55 |
+
using scalar_t_out = float; \
|
56 |
+
__VA_ARGS__; \
|
57 |
+
break; \
|
58 |
+
} \
|
59 |
+
case at::ScalarType::Half: \
|
60 |
+
{ \
|
61 |
+
using scalar_t_out = at::Half; \
|
62 |
+
__VA_ARGS__; \
|
63 |
+
break; \
|
64 |
+
} \
|
65 |
+
case at::ScalarType::BFloat16: \
|
66 |
+
{ \
|
67 |
+
using scalar_t_out = at::BFloat16; \
|
68 |
+
__VA_ARGS__; \
|
69 |
+
break; \
|
70 |
+
} \
|
71 |
+
default: \
|
72 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
73 |
+
} \
|
74 |
+
break; \
|
75 |
+
} \
|
76 |
+
case at::ScalarType::Half: \
|
77 |
+
{ \
|
78 |
+
using scalar_t_in = at::Half; \
|
79 |
+
using scalar_t_out = at::Half; \
|
80 |
+
__VA_ARGS__; \
|
81 |
+
break; \
|
82 |
+
} \
|
83 |
+
case at::ScalarType::BFloat16: \
|
84 |
+
{ \
|
85 |
+
using scalar_t_in = at::BFloat16; \
|
86 |
+
using scalar_t_out = at::BFloat16; \
|
87 |
+
__VA_ARGS__; \
|
88 |
+
break; \
|
89 |
+
} \
|
90 |
+
default: \
|
91 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
92 |
+
}
|
src/third_party/BigVGAN/alias_free_activation/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
src/third_party/BigVGAN/alias_free_activation/torch/act.py
ADDED
@@ -0,0 +1,30 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
src/third_party/BigVGAN/alias_free_activation/torch/filter.py
ADDED
@@ -0,0 +1,101 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        """
        Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
        """
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
        """
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # Input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
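A minimal sketch of these helpers in isolation (the cutoff/half-width values mirror what a ratio-2 downsampler passes in; the signal sizes are assumptions):

# Sketch: build a 12-tap Kaiser-windowed sinc low-pass and apply it with stride 2.
import torch
from alias_free_activation.torch.filter import kaiser_sinc_filter1d, LowPassFilter1d

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape)  # torch.Size([1, 1, 12]); the taps are normalized to sum to 1

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 4, 1000)  # (B, C, T), illustrative sizes
y = lpf(x)                   # (1, 4, 500): low-pass filtered and decimated by the stride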
src/third_party/BigVGAN/alias_free_activation/torch/resample.py
ADDED
@@ -0,0 +1,58 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from alias_free_activation.torch.filter import LowPassFilter1d
from alias_free_activation.torch.filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
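As a quick sketch of the resamplers above (sizes are illustrative): a 2x anti-aliased upsample followed by a 2x downsample returns a tensor of the original length.

import torch
from alias_free_activation.torch.resample import UpSample1d, DownSample1d

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 8, 512)  # (B, C, T)
hi = up(x)                  # (1, 8, 1024)
lo = down(hi)               # (1, 8, 512), back to the input length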
src/third_party/BigVGAN/bigvgan.py
ADDED
@@ -0,0 +1,493 @@
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional, Union, Dict
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
15 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
16 |
+
|
17 |
+
import activations
|
18 |
+
from utils import init_weights, get_padding
|
19 |
+
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
20 |
+
from env import AttrDict
|
21 |
+
|
22 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
23 |
+
|
24 |
+
|
25 |
+
def load_hparams_from_json(path) -> AttrDict:
|
26 |
+
with open(path) as f:
|
27 |
+
data = f.read()
|
28 |
+
return AttrDict(json.loads(data))
|
29 |
+
|
30 |
+
|
31 |
+
class AMPBlock1(torch.nn.Module):
|
32 |
+
"""
|
33 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
34 |
+
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
35 |
+
|
36 |
+
Args:
|
37 |
+
h (AttrDict): Hyperparameters.
|
38 |
+
channels (int): Number of convolution channels.
|
39 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
40 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
41 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
h: AttrDict,
|
47 |
+
channels: int,
|
48 |
+
kernel_size: int = 3,
|
49 |
+
dilation: tuple = (1, 3, 5),
|
50 |
+
activation: str = None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.h = h
|
55 |
+
|
56 |
+
self.convs1 = nn.ModuleList(
|
57 |
+
[
|
58 |
+
weight_norm(
|
59 |
+
Conv1d(
|
60 |
+
channels,
|
61 |
+
channels,
|
62 |
+
kernel_size,
|
63 |
+
stride=1,
|
64 |
+
dilation=d,
|
65 |
+
padding=get_padding(kernel_size, d),
|
66 |
+
)
|
67 |
+
)
|
68 |
+
for d in dilation
|
69 |
+
]
|
70 |
+
)
|
71 |
+
self.convs1.apply(init_weights)
|
72 |
+
|
73 |
+
self.convs2 = nn.ModuleList(
|
74 |
+
[
|
75 |
+
weight_norm(
|
76 |
+
Conv1d(
|
77 |
+
channels,
|
78 |
+
channels,
|
79 |
+
kernel_size,
|
80 |
+
stride=1,
|
81 |
+
dilation=1,
|
82 |
+
padding=get_padding(kernel_size, 1),
|
83 |
+
)
|
84 |
+
)
|
85 |
+
for _ in range(len(dilation))
|
86 |
+
]
|
87 |
+
)
|
88 |
+
self.convs2.apply(init_weights)
|
89 |
+
|
90 |
+
self.num_layers = len(self.convs1) + len(
|
91 |
+
self.convs2
|
92 |
+
) # Total number of conv layers
|
93 |
+
|
94 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
95 |
+
if self.h.get("use_cuda_kernel", False):
|
96 |
+
from alias_free_activation.cuda.activation1d import (
|
97 |
+
Activation1d as CudaActivation1d,
|
98 |
+
)
|
99 |
+
|
100 |
+
Activation1d = CudaActivation1d
|
101 |
+
else:
|
102 |
+
Activation1d = TorchActivation1d
|
103 |
+
|
104 |
+
# Activation functions
|
105 |
+
if activation == "snake":
|
106 |
+
self.activations = nn.ModuleList(
|
107 |
+
[
|
108 |
+
Activation1d(
|
109 |
+
activation=activations.Snake(
|
110 |
+
channels, alpha_logscale=h.snake_logscale
|
111 |
+
)
|
112 |
+
)
|
113 |
+
for _ in range(self.num_layers)
|
114 |
+
]
|
115 |
+
)
|
116 |
+
elif activation == "snakebeta":
|
117 |
+
self.activations = nn.ModuleList(
|
118 |
+
[
|
119 |
+
Activation1d(
|
120 |
+
activation=activations.SnakeBeta(
|
121 |
+
channels, alpha_logscale=h.snake_logscale
|
122 |
+
)
|
123 |
+
)
|
124 |
+
for _ in range(self.num_layers)
|
125 |
+
]
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
raise NotImplementedError(
|
129 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
134 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
135 |
+
xt = a1(x)
|
136 |
+
xt = c1(xt)
|
137 |
+
xt = a2(xt)
|
138 |
+
xt = c2(xt)
|
139 |
+
x = xt + x
|
140 |
+
|
141 |
+
return x
|
142 |
+
|
143 |
+
def remove_weight_norm(self):
|
144 |
+
for l in self.convs1:
|
145 |
+
remove_weight_norm(l)
|
146 |
+
for l in self.convs2:
|
147 |
+
remove_weight_norm(l)
|
148 |
+
|
149 |
+
|
150 |
+
class AMPBlock2(torch.nn.Module):
|
151 |
+
"""
|
152 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
153 |
+
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
154 |
+
|
155 |
+
Args:
|
156 |
+
h (AttrDict): Hyperparameters.
|
157 |
+
channels (int): Number of convolution channels.
|
158 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
159 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
160 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
h: AttrDict,
|
166 |
+
channels: int,
|
167 |
+
kernel_size: int = 3,
|
168 |
+
dilation: tuple = (1, 3, 5),
|
169 |
+
activation: str = None,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
self.h = h
|
174 |
+
|
175 |
+
self.convs = nn.ModuleList(
|
176 |
+
[
|
177 |
+
weight_norm(
|
178 |
+
Conv1d(
|
179 |
+
channels,
|
180 |
+
channels,
|
181 |
+
kernel_size,
|
182 |
+
stride=1,
|
183 |
+
dilation=d,
|
184 |
+
padding=get_padding(kernel_size, d),
|
185 |
+
)
|
186 |
+
)
|
187 |
+
for d in dilation
|
188 |
+
]
|
189 |
+
)
|
190 |
+
self.convs.apply(init_weights)
|
191 |
+
|
192 |
+
self.num_layers = len(self.convs) # Total number of conv layers
|
193 |
+
|
194 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
195 |
+
if self.h.get("use_cuda_kernel", False):
|
196 |
+
from alias_free_activation.cuda.activation1d import (
|
197 |
+
Activation1d as CudaActivation1d,
|
198 |
+
)
|
199 |
+
|
200 |
+
Activation1d = CudaActivation1d
|
201 |
+
else:
|
202 |
+
Activation1d = TorchActivation1d
|
203 |
+
|
204 |
+
# Activation functions
|
205 |
+
if activation == "snake":
|
206 |
+
self.activations = nn.ModuleList(
|
207 |
+
[
|
208 |
+
Activation1d(
|
209 |
+
activation=activations.Snake(
|
210 |
+
channels, alpha_logscale=h.snake_logscale
|
211 |
+
)
|
212 |
+
)
|
213 |
+
for _ in range(self.num_layers)
|
214 |
+
]
|
215 |
+
)
|
216 |
+
elif activation == "snakebeta":
|
217 |
+
self.activations = nn.ModuleList(
|
218 |
+
[
|
219 |
+
Activation1d(
|
220 |
+
activation=activations.SnakeBeta(
|
221 |
+
channels, alpha_logscale=h.snake_logscale
|
222 |
+
)
|
223 |
+
)
|
224 |
+
for _ in range(self.num_layers)
|
225 |
+
]
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
raise NotImplementedError(
|
229 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
230 |
+
)
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
for c, a in zip(self.convs, self.activations):
|
234 |
+
xt = a(x)
|
235 |
+
xt = c(xt)
|
236 |
+
x = xt + x
|
237 |
+
return x
|
238 |
+
|
239 |
+
def remove_weight_norm(self):
|
240 |
+
for l in self.convs:
|
241 |
+
remove_weight_norm(l)
|
242 |
+
|
243 |
+
|
244 |
+
class BigVGAN(
|
245 |
+
torch.nn.Module,
|
246 |
+
PyTorchModelHubMixin,
|
247 |
+
library_name="bigvgan",
|
248 |
+
repo_url="https://github.com/NVIDIA/BigVGAN",
|
249 |
+
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
250 |
+
pipeline_tag="audio-to-audio",
|
251 |
+
license="mit",
|
252 |
+
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
253 |
+
):
|
254 |
+
"""
|
255 |
+
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
256 |
+
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
h (AttrDict): Hyperparameters.
|
260 |
+
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
261 |
+
|
262 |
+
Note:
|
263 |
+
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
264 |
+
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
268 |
+
super().__init__()
|
269 |
+
self.h = h
|
270 |
+
self.h["use_cuda_kernel"] = use_cuda_kernel
|
271 |
+
|
272 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
273 |
+
if self.h.get("use_cuda_kernel", False):
|
274 |
+
from alias_free_activation.cuda.activation1d import (
|
275 |
+
Activation1d as CudaActivation1d,
|
276 |
+
)
|
277 |
+
|
278 |
+
Activation1d = CudaActivation1d
|
279 |
+
else:
|
280 |
+
Activation1d = TorchActivation1d
|
281 |
+
|
282 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
283 |
+
self.num_upsamples = len(h.upsample_rates)
|
284 |
+
|
285 |
+
# Pre-conv
|
286 |
+
self.conv_pre = weight_norm(
|
287 |
+
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
288 |
+
)
|
289 |
+
|
290 |
+
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
291 |
+
if h.resblock == "1":
|
292 |
+
resblock_class = AMPBlock1
|
293 |
+
elif h.resblock == "2":
|
294 |
+
resblock_class = AMPBlock2
|
295 |
+
else:
|
296 |
+
raise ValueError(
|
297 |
+
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
298 |
+
)
|
299 |
+
|
300 |
+
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
301 |
+
self.ups = nn.ModuleList()
|
302 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
303 |
+
self.ups.append(
|
304 |
+
nn.ModuleList(
|
305 |
+
[
|
306 |
+
weight_norm(
|
307 |
+
ConvTranspose1d(
|
308 |
+
h.upsample_initial_channel // (2**i),
|
309 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
310 |
+
k,
|
311 |
+
u,
|
312 |
+
padding=(k - u) // 2,
|
313 |
+
)
|
314 |
+
)
|
315 |
+
]
|
316 |
+
)
|
317 |
+
)
|
318 |
+
|
319 |
+
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
320 |
+
self.resblocks = nn.ModuleList()
|
321 |
+
for i in range(len(self.ups)):
|
322 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
323 |
+
for j, (k, d) in enumerate(
|
324 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
325 |
+
):
|
326 |
+
self.resblocks.append(
|
327 |
+
resblock_class(h, ch, k, d, activation=h.activation)
|
328 |
+
)
|
329 |
+
|
330 |
+
# Post-conv
|
331 |
+
activation_post = (
|
332 |
+
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
333 |
+
if h.activation == "snake"
|
334 |
+
else (
|
335 |
+
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
336 |
+
if h.activation == "snakebeta"
|
337 |
+
else None
|
338 |
+
)
|
339 |
+
)
|
340 |
+
if activation_post is None:
|
341 |
+
raise NotImplementedError(
|
342 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
343 |
+
)
|
344 |
+
|
345 |
+
self.activation_post = Activation1d(activation=activation_post)
|
346 |
+
|
347 |
+
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
348 |
+
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
349 |
+
self.conv_post = weight_norm(
|
350 |
+
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
351 |
+
)
|
352 |
+
|
353 |
+
# Weight initialization
|
354 |
+
for i in range(len(self.ups)):
|
355 |
+
self.ups[i].apply(init_weights)
|
356 |
+
self.conv_post.apply(init_weights)
|
357 |
+
|
358 |
+
# Final tanh activation. Defaults to True for backward compatibility
|
359 |
+
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
360 |
+
|
361 |
+
def forward(self, x):
|
362 |
+
# Pre-conv
|
363 |
+
x = self.conv_pre(x)
|
364 |
+
|
365 |
+
for i in range(self.num_upsamples):
|
366 |
+
# Upsampling
|
367 |
+
for i_up in range(len(self.ups[i])):
|
368 |
+
x = self.ups[i][i_up](x)
|
369 |
+
# AMP blocks
|
370 |
+
xs = None
|
371 |
+
for j in range(self.num_kernels):
|
372 |
+
if xs is None:
|
373 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
374 |
+
else:
|
375 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
376 |
+
x = xs / self.num_kernels
|
377 |
+
|
378 |
+
# Post-conv
|
379 |
+
x = self.activation_post(x)
|
380 |
+
x = self.conv_post(x)
|
381 |
+
# Final tanh activation
|
382 |
+
if self.use_tanh_at_final:
|
383 |
+
x = torch.tanh(x)
|
384 |
+
else:
|
385 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
386 |
+
|
387 |
+
return x
|
388 |
+
|
389 |
+
def remove_weight_norm(self):
|
390 |
+
try:
|
391 |
+
print("Removing weight norm...")
|
392 |
+
for l in self.ups:
|
393 |
+
for l_i in l:
|
394 |
+
remove_weight_norm(l_i)
|
395 |
+
for l in self.resblocks:
|
396 |
+
l.remove_weight_norm()
|
397 |
+
remove_weight_norm(self.conv_pre)
|
398 |
+
remove_weight_norm(self.conv_post)
|
399 |
+
except ValueError:
|
400 |
+
print("[INFO] Model already removed weight norm. Skipping!")
|
401 |
+
pass
|
402 |
+
|
403 |
+
# Additional methods for huggingface_hub support
|
404 |
+
def _save_pretrained(self, save_directory: Path) -> None:
|
405 |
+
"""Save weights and config.json from a Pytorch model to a local directory."""
|
406 |
+
|
407 |
+
model_path = save_directory / "bigvgan_generator.pt"
|
408 |
+
torch.save({"generator": self.state_dict()}, model_path)
|
409 |
+
|
410 |
+
config_path = save_directory / "config.json"
|
411 |
+
with open(config_path, "w") as config_file:
|
412 |
+
json.dump(self.h, config_file, indent=4)
|
413 |
+
|
414 |
+
@classmethod
|
415 |
+
def _from_pretrained(
|
416 |
+
cls,
|
417 |
+
*,
|
418 |
+
model_id: str,
|
419 |
+
revision: str,
|
420 |
+
cache_dir: str,
|
421 |
+
force_download: bool,
|
422 |
+
proxies: Optional[Dict],
|
423 |
+
resume_download: bool,
|
424 |
+
local_files_only: bool,
|
425 |
+
token: Union[str, bool, None],
|
426 |
+
map_location: str = "cpu", # Additional argument
|
427 |
+
strict: bool = False, # Additional argument
|
428 |
+
use_cuda_kernel: bool = False,
|
429 |
+
**model_kwargs,
|
430 |
+
):
|
431 |
+
"""Load Pytorch pretrained weights and return the loaded model."""
|
432 |
+
|
433 |
+
# Download and load hyperparameters (h) used by BigVGAN
|
434 |
+
if os.path.isdir(model_id):
|
435 |
+
print("Loading config.json from local directory")
|
436 |
+
config_file = os.path.join(model_id, "config.json")
|
437 |
+
else:
|
438 |
+
config_file = hf_hub_download(
|
439 |
+
repo_id=model_id,
|
440 |
+
filename="config.json",
|
441 |
+
revision=revision,
|
442 |
+
cache_dir=cache_dir,
|
443 |
+
force_download=force_download,
|
444 |
+
proxies=proxies,
|
445 |
+
resume_download=resume_download,
|
446 |
+
token=token,
|
447 |
+
local_files_only=local_files_only,
|
448 |
+
)
|
449 |
+
h = load_hparams_from_json(config_file)
|
450 |
+
|
451 |
+
# instantiate BigVGAN using h
|
452 |
+
if use_cuda_kernel:
|
453 |
+
print(
|
454 |
+
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
455 |
+
)
|
456 |
+
print(
|
457 |
+
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
458 |
+
)
|
459 |
+
print(
|
460 |
+
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
461 |
+
)
|
462 |
+
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
463 |
+
|
464 |
+
# Download and load pretrained generator weight
|
465 |
+
if os.path.isdir(model_id):
|
466 |
+
print("Loading weights from local directory")
|
467 |
+
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
468 |
+
else:
|
469 |
+
print(f"Loading weights from {model_id}")
|
470 |
+
model_file = hf_hub_download(
|
471 |
+
repo_id=model_id,
|
472 |
+
filename="bigvgan_generator.pt",
|
473 |
+
revision=revision,
|
474 |
+
cache_dir=cache_dir,
|
475 |
+
force_download=force_download,
|
476 |
+
proxies=proxies,
|
477 |
+
resume_download=resume_download,
|
478 |
+
token=token,
|
479 |
+
local_files_only=local_files_only,
|
480 |
+
)
|
481 |
+
|
482 |
+
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
483 |
+
|
484 |
+
try:
|
485 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
486 |
+
except RuntimeError:
|
487 |
+
print(
|
488 |
+
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
489 |
+
)
|
490 |
+
model.remove_weight_norm()
|
491 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
492 |
+
|
493 |
+
return model
|
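For context, a hedged inference sketch using the huggingface_hub hooks defined in this file; the checkpoint id is one of NVIDIA's published BigVGAN checkpoints and the mel tensor is a random stand-in, so treat both as illustrative assumptions:

# Sketch: load a pretrained generator and run mel -> waveform inference on CPU.
import torch
from bigvgan import BigVGAN

model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_22khz_80band_256x", use_cuda_kernel=False)
model.remove_weight_norm()  # inference-only, mirroring the weight-norm handling above
model.eval()

mel = torch.randn(1, model.h.num_mels, 200)  # (B, num_mels, frames), random stand-in
with torch.inference_mode():
    wav = model(mel)  # (B, 1, frames * 256) given the 256x total upsampling of this config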
src/third_party/BigVGAN/configs/bigvgan_22khz_80band.json
ADDED
@@ -0,0 +1,45 @@
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 0,
|
4 |
+
"batch_size": 32,
|
5 |
+
"learning_rate": 0.0001,
|
6 |
+
"adam_b1": 0.8,
|
7 |
+
"adam_b2": 0.99,
|
8 |
+
"lr_decay": 0.9999996,
|
9 |
+
"seed": 1234,
|
10 |
+
|
11 |
+
"upsample_rates": [4,4,2,2,2,2],
|
12 |
+
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
13 |
+
"upsample_initial_channel": 1536,
|
14 |
+
"resblock_kernel_sizes": [3,7,11],
|
15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
16 |
+
|
17 |
+
"activation": "snakebeta",
|
18 |
+
"snake_logscale": true,
|
19 |
+
|
20 |
+
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
21 |
+
"mpd_reshapes": [2, 3, 5, 7, 11],
|
22 |
+
"use_spectral_norm": false,
|
23 |
+
"discriminator_channel_mult": 1,
|
24 |
+
|
25 |
+
"segment_size": 8192,
|
26 |
+
"num_mels": 80,
|
27 |
+
"num_freq": 1025,
|
28 |
+
"n_fft": 1024,
|
29 |
+
"hop_size": 256,
|
30 |
+
"win_size": 1024,
|
31 |
+
|
32 |
+
"sampling_rate": 22050,
|
33 |
+
|
34 |
+
"fmin": 0,
|
35 |
+
"fmax": 8000,
|
36 |
+
"fmax_for_loss": null,
|
37 |
+
|
38 |
+
"num_workers": 4,
|
39 |
+
|
40 |
+
"dist_config": {
|
41 |
+
"dist_backend": "nccl",
|
42 |
+
"dist_url": "tcp://localhost:54321",
|
43 |
+
"world_size": 1
|
44 |
+
}
|
45 |
+
}
|
src/third_party/BigVGAN/configs/bigvgan_24khz_100band.json
ADDED
@@ -0,0 +1,45 @@
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 0,
|
4 |
+
"batch_size": 32,
|
5 |
+
"learning_rate": 0.0001,
|
6 |
+
"adam_b1": 0.8,
|
7 |
+
"adam_b2": 0.99,
|
8 |
+
"lr_decay": 0.9999996,
|
9 |
+
"seed": 1234,
|
10 |
+
|
11 |
+
"upsample_rates": [4,4,2,2,2,2],
|
12 |
+
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
13 |
+
"upsample_initial_channel": 1536,
|
14 |
+
"resblock_kernel_sizes": [3,7,11],
|
15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
16 |
+
|
17 |
+
"activation": "snakebeta",
|
18 |
+
"snake_logscale": true,
|
19 |
+
|
20 |
+
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
21 |
+
"mpd_reshapes": [2, 3, 5, 7, 11],
|
22 |
+
"use_spectral_norm": false,
|
23 |
+
"discriminator_channel_mult": 1,
|
24 |
+
|
25 |
+
"segment_size": 8192,
|
26 |
+
"num_mels": 100,
|
27 |
+
"num_freq": 1025,
|
28 |
+
"n_fft": 1024,
|
29 |
+
"hop_size": 256,
|
30 |
+
"win_size": 1024,
|
31 |
+
|
32 |
+
"sampling_rate": 24000,
|
33 |
+
|
34 |
+
"fmin": 0,
|
35 |
+
"fmax": 12000,
|
36 |
+
"fmax_for_loss": null,
|
37 |
+
|
38 |
+
"num_workers": 4,
|
39 |
+
|
40 |
+
"dist_config": {
|
41 |
+
"dist_backend": "nccl",
|
42 |
+
"dist_url": "tcp://localhost:54321",
|
43 |
+
"world_size": 1
|
44 |
+
}
|
45 |
+
}
|
src/third_party/BigVGAN/configs/bigvgan_base_22khz_80band.json
ADDED
@@ -0,0 +1,45 @@
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 0,
|
4 |
+
"batch_size": 32,
|
5 |
+
"learning_rate": 0.0001,
|
6 |
+
"adam_b1": 0.8,
|
7 |
+
"adam_b2": 0.99,
|
8 |
+
"lr_decay": 0.9999996,
|
9 |
+
"seed": 1234,
|
10 |
+
|
11 |
+
"upsample_rates": [8,8,2,2],
|
12 |
+
"upsample_kernel_sizes": [16,16,4,4],
|
13 |
+
"upsample_initial_channel": 512,
|
14 |
+
"resblock_kernel_sizes": [3,7,11],
|
15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
16 |
+
|
17 |
+
"activation": "snakebeta",
|
18 |
+
"snake_logscale": true,
|
19 |
+
|
20 |
+
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
21 |
+
"mpd_reshapes": [2, 3, 5, 7, 11],
|
22 |
+
"use_spectral_norm": false,
|
23 |
+
"discriminator_channel_mult": 1,
|
24 |
+
|
25 |
+
"segment_size": 8192,
|
26 |
+
"num_mels": 80,
|
27 |
+
"num_freq": 1025,
|
28 |
+
"n_fft": 1024,
|
29 |
+
"hop_size": 256,
|
30 |
+
"win_size": 1024,
|
31 |
+
|
32 |
+
"sampling_rate": 22050,
|
33 |
+
|
34 |
+
"fmin": 0,
|
35 |
+
"fmax": 8000,
|
36 |
+
"fmax_for_loss": null,
|
37 |
+
|
38 |
+
"num_workers": 4,
|
39 |
+
|
40 |
+
"dist_config": {
|
41 |
+
"dist_backend": "nccl",
|
42 |
+
"dist_url": "tcp://localhost:54321",
|
43 |
+
"world_size": 1
|
44 |
+
}
|
45 |
+
}
|
src/third_party/BigVGAN/configs/bigvgan_base_24khz_100band.json
ADDED
@@ -0,0 +1,45 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 32,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [8,8,2,2],
+    "upsample_kernel_sizes": [16,16,4,4],
+    "upsample_initial_channel": 512,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "segment_size": 8192,
+    "num_mels": 100,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 24000,
+
+    "fmin": 0,
+    "fmax": 12000,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json
ADDED
@@ -0,0 +1,61 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 4,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 22050,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json
ADDED
@@ -0,0 +1,61 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 4,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 22050,
+
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
src/third_party/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json
ADDED
@@ -0,0 +1,61 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 4,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 100,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 24000,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json
ADDED
@@ -0,0 +1,61 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 4,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 128,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 44100,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json
ADDED
@@ -0,0 +1,61 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 4,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [8,4,2,2,2,2],
+    "upsample_kernel_sizes": [16,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "use_tanh_at_final": false,
+    "use_bias_at_final": false,
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "use_cqtd_instead_of_mrd": true,
+    "cqtd_filters": 128,
+    "cqtd_max_filters": 1024,
+    "cqtd_filters_scale": 1,
+    "cqtd_dilations": [1, 2, 4],
+    "cqtd_hop_lengths": [512, 256, 256],
+    "cqtd_n_octaves": [9, 9, 9],
+    "cqtd_bins_per_octaves": [24, 36, 48],
+
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "use_multiscale_melloss": true,
+    "lambda_melloss": 15,
+
+    "clip_grad_norm": 500,
+
+    "segment_size": 65536,
+    "num_mels": 128,
+    "num_freq": 2049,
+    "n_fft": 2048,
+    "hop_size": 512,
+    "win_size": 2048,
+
+    "sampling_rate": 44100,
+
+    "fmin": 0,
+    "fmax": null,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
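The `_256x` / `_512x` suffixes on the v2 configs refer to the total upsampling factor, i.e. the product of `upsample_rates`, which matches `hop_size` so that each mel frame expands into exactly one hop of output samples. A small illustrative check (not part of the repository) over the two 44 kHz configs added above:

```python
# Illustrative consistency check: product of upsample_rates should equal hop_size,
# i.e. the "..._256x" / "..._512x" upsampling factor in the config names.
import json
import math

configs = [
    "src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json",
    "src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json",
]

for path in configs:
    with open(path) as f:
        h = json.load(f)
    factor = math.prod(h["upsample_rates"])  # 4*4*2*2*2*2 = 256 or 8*4*2*2*2*2 = 512
    assert factor == h["hop_size"], f"{path}: {factor} != {h['hop_size']}"
    print(path, "upsamples", factor, "x per mel frame at", h["sampling_rate"], "Hz")
```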