Commit 9d9ac6c · Opus committed
1 Parent(s): 597284f
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +180 -0
  2. README.md +5 -5
  3. app.py +189 -0
  4. datasets/CoTMovieDubbing/README.md +20 -0
  5. datasets/CoTMovieDubbing/filelist/cot_spk_for_speech_gen.lst +0 -0
  6. datasets/CoTMovieDubbing/filelist/mmlm_test.jsonl +0 -0
  7. datasets/CoTMovieDubbing/filelist/mmlm_train.jsonl +0 -0
  8. datasets/Grid/README.md +1 -0
  9. datasets/V2C/README.md +1 -0
  10. datasets/V2C/V2C_Setting2.txt +0 -0
  11. datasets/V2C/V2C_Setting3.txt +0 -0
  12. requirements.txt +237 -0
  13. ruff.toml +11 -0
  14. src/internvl/eval.py +337 -0
  15. src/moviedubber/configs/basemodel.yaml +9 -0
  16. src/moviedubber/eval.py +245 -0
  17. src/moviedubber/infer/basic.toml +4 -0
  18. src/moviedubber/infer/utils_infer.py +399 -0
  19. src/moviedubber/infer/video_preprocess.py +315 -0
  20. src/moviedubber/infer_with_mmlm_result.py +339 -0
  21. src/moviedubber/model/__init__.py +5 -0
  22. src/moviedubber/model/cfm.py +209 -0
  23. src/moviedubber/model/dit.py +297 -0
  24. src/moviedubber/model/modules.py +467 -0
  25. src/moviedubber/model/utils.py +128 -0
  26. src/third_party/BigVGAN/.gitignore +146 -0
  27. src/third_party/BigVGAN/LICENSE +21 -0
  28. src/third_party/BigVGAN/README.md +266 -0
  29. src/third_party/BigVGAN/activations.py +126 -0
  30. src/third_party/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  31. src/third_party/BigVGAN/alias_free_activation/cuda/activation1d.py +77 -0
  32. src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  33. src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  34. src/third_party/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  35. src/third_party/BigVGAN/alias_free_activation/cuda/load.py +86 -0
  36. src/third_party/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  37. src/third_party/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  38. src/third_party/BigVGAN/alias_free_activation/torch/act.py +30 -0
  39. src/third_party/BigVGAN/alias_free_activation/torch/filter.py +101 -0
  40. src/third_party/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  41. src/third_party/BigVGAN/bigvgan.py +493 -0
  42. src/third_party/BigVGAN/configs/bigvgan_22khz_80band.json +45 -0
  43. src/third_party/BigVGAN/configs/bigvgan_24khz_100band.json +45 -0
  44. src/third_party/BigVGAN/configs/bigvgan_base_22khz_80band.json +45 -0
  45. src/third_party/BigVGAN/configs/bigvgan_base_24khz_100band.json +45 -0
  46. src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json +61 -0
  47. src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json +61 -0
  48. src/third_party/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json +61 -0
  49. src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json +61 -0
  50. src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json +61 -0
.gitignore ADDED
@@ -0,0 +1,180 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+ data
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .idea/
163
+
164
+ .DS_Store
165
+ data_process/
166
+ internvl_chat/work_dirs/
167
+ internvl_chat/unittest/
168
+ internvl_chat/data/
169
+ Husky2/*
170
+ data_process/
171
+ *distillation*
172
+
173
+ batchscript-*
174
+ results/
175
+
176
+ # *txt
177
+ *csv
178
+ *mp4
179
+ temp/
180
+ src/moviedubber/infer/basic_test.toml
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: DeepDubber V1
3
- emoji: 🐠
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.22.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Deepdubber V1
3
+ emoji: 🌍
4
+ colorFrom: red
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.22.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import tempfile
3
+
4
+ import gradio as gr
5
+ import librosa
6
+ import soundfile
7
+ import tomli
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ from moviepy import VideoFileClip
12
+ from pydub import AudioSegment
13
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
14
+
15
+ from src.moviedubber.infer.utils_infer import (
16
+ cfg_strength,
17
+ chunk_text,
18
+ nfe_step,
19
+ sway_sampling_coef,
20
+ )
21
+ from src.moviedubber.infer.video_preprocess import VideoFeatureExtractor
22
+ from src.moviedubber.infer_with_mmlm_result import concat_movie_with_audio, get_spk_emb, load_models
23
+ from src.moviedubber.model.utils import convert_char_to_pinyin
24
+
25
+
26
+ def load_asr_model(model_id="openai/whisper-large-v3-turbo"):
27
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
28
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
29
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
30
+ ).to(device)
31
+
32
+ processor = AutoProcessor.from_pretrained(model_id)
33
+
34
+ pipe = pipeline(
35
+ "automatic-speech-recognition",
36
+ model=model,
37
+ tokenizer=processor.tokenizer,
38
+ feature_extractor=processor.feature_extractor,
39
+ torch_dtype=torch_dtype,
40
+ device=device,
41
+ )
42
+ return pipe
43
+
44
+
45
+ device = "cpu"
46
+ config = tomli.load(open("src/moviedubber/infer/basic.toml", "rb"))
47
+
48
+
49
+ ema_model, vocoder, ort_session = load_models(config, device=device)
50
+ asr_pipe = load_asr_model()
51
+
52
+ videofeature_extractor = VideoFeatureExtractor(device=device)
53
+
54
+
55
+ def deepdubber(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
56
+ print(f"Starting deepdubber with video_path: {video_path} and subtitle_text: {subtitle_text}")
57
+ gen_clip = videofeature_extractor.extract_features(video_path)
58
+ gen_text = subtitle_text
59
+
60
+ clip = VideoFileClip(video_path)
61
+ gen_audio_len = int(clip.duration * 24000 // 256)
62
+
63
+ gen_clip = gen_clip.unsqueeze(0).to(device=device, dtype=torch.float32).transpose(1, 2)
64
+ gen_clip = F.interpolate(gen_clip, size=(gen_audio_len,), mode="linear", align_corners=False).transpose(1, 2)
65
+
66
+ ref_audio_len = None
67
+ if audio_path is not None:
68
+ print("reference audio is not None, dubbing with reference audio")
69
+
70
+ if audio_path.endswith(".mp3"):
71
+ audio = AudioSegment.from_mp3(audio_path)
72
+
73
+ wav_file = audio_path.replace(".mp3", ".wav")
74
+ audio.export(wav_file, format="wav")
75
+ else:
76
+ wav_file = audio_path
77
+
78
+ ref_text = asr_pipe(librosa.load(wav_file, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
79
+ ref_text = ref_text.replace("\n", " ").replace("\r", " ")
80
+ print(f"Reference text: {ref_text}")
81
+
82
+ spk_emb = get_spk_emb(wav_file, ort_session)
83
+ spk_emb = torch.tensor(spk_emb).to(device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
84
+
85
+ audio_data, sr = torchaudio.load(wav_file)
86
+ resampler = torchaudio.transforms.Resample(sr, 24000)
87
+ if sr != 24000:
88
+ audio_data = resampler(audio_data)
89
+
90
+ if audio_data.shape[0] > 1:
91
+ audio_data = torch.mean(audio_data, dim=0, keepdim=True)
92
+
93
+ audio_data = audio_data.to(device)
94
+
95
+ ref_audio_len = int(audio_data.shape[-1] // 256)
96
+ ref_clip = torch.zeros((1, ref_audio_len, 768)).to(device=device)
97
+
98
+ gen_clip = torch.cat((gen_clip, ref_clip), dim=1)
99
+
100
+ gen_audio_len = ref_audio_len + gen_audio_len
101
+
102
+ gen_text = ref_text + " " + gen_text
103
+
104
+ else:
105
+ spk_emb = torch.zeros((1, 1, 192)).to(device=device)
106
+ audio_data = torch.zeros((1, gen_audio_len, 100)).to(device=device)
107
+
108
+ gen_text_batches = chunk_text(gen_text, max_chars=1024)
109
+ final_text_list = convert_char_to_pinyin(gen_text_batches)
110
+
111
+ with torch.inference_mode():
112
+ generated, _ = ema_model.sample(
113
+ cond=audio_data,
114
+ text=final_text_list,
115
+ clip=gen_clip,
116
+ spk_emb=spk_emb,
117
+ duration=gen_audio_len,
118
+ steps=nfe_step,
119
+ cfg_strength=cfg_strength,
120
+ sway_sampling_coef=sway_sampling_coef,
121
+ no_ref_audio=False,
122
+ )
123
+
124
+ generated = generated.to(torch.float32)
125
+
126
+ if ref_audio_len is not None:
127
+ generated = generated[:, ref_audio_len:, :]
128
+
129
+ generated_mel_spec = generated.permute(0, 2, 1)
130
+ generated_wave = vocoder(generated_mel_spec)
131
+
132
+ generated_wave = generated_wave.squeeze().cpu().numpy()
133
+
134
+ # using a temporary wav file to save the generated audio
135
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_wav_file:
136
+ temp_wav_path = temp_wav_file.name
137
+ soundfile.write(temp_wav_path, generated_wave, samplerate=24000)
138
+
139
+ concated_video = concat_movie_with_audio(temp_wav_path, video_path, ".")
140
+
141
+ # Ensure the temporary file is deleted after use
142
+ os.remove(temp_wav_path)
143
+
144
+ print(f"Deepdubber completed successfully, output path: {concated_video}")
145
+ return concated_video
146
+
147
+
148
+ def process_video_dubbing(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
149
+ try:
150
+ print(f"Processing video: {video_path}")
151
+ if not os.path.exists(video_path):
152
+ raise ValueError("Video file does not exist")
153
+
154
+ if not subtitle_text.strip():
155
+ raise ValueError("Subtitle text cannot be empty")
156
+
157
+ output_path = deepdubber(video_path, subtitle_text, audio_path)
158
+
159
+ return output_path
160
+
161
+ except Exception as e:
162
+ print(f"Error in process_video_dubbing: {e}")
163
+
164
+ return None
165
+
166
+
167
+ def create_ui():
168
+ with gr.Blocks(title="DeepDubber-V1") as app:
169
+ gr.Markdown("# DeepDubber-V1\nUpload your video file and enter the text you want to dub")
170
+
171
+ with gr.Row():
172
+ video_input = gr.Video(label="Upload video")
173
+ audio_input = gr.Audio(label="Upload audio", type="filepath")
174
+ subtitle_input = gr.Textbox(label="Enter the text", placeholder="Enter the text to be dubbed...", lines=5)
175
+
176
+ process_btn = gr.Button("Start Dubbing")
177
+
178
+ output_video = gr.Video(label="Dubbed Video")
179
+
180
+ process_btn.click(
181
+ fn=process_video_dubbing, inputs=[video_input, subtitle_input, audio_input], outputs=output_video
182
+ )
183
+
184
+ return app
185
+
186
+
187
+ if __name__ == "__main__":
188
+ app = create_ui()
189
+ app.launch()
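
> Illustrative sketch (not part of the commit): `app.py` wires `process_video_dubbing` into a Gradio UI, but the same function can be called directly for a quick smoke test. The input paths below are hypothetical placeholders; the models configured in `src/moviedubber/infer/basic.toml` must be available.

```python
# Minimal programmatic check of the dubbing pipeline defined in app.py.
# Paths are placeholders, not files shipped with the repository.
from app import process_video_dubbing

output = process_video_dubbing(
    video_path="examples/clip.mp4",          # hypothetical input clip
    subtitle_text="Hello there, welcome back.",
    audio_path="examples/reference.wav",     # optional reference voice; may be None
)
print("dubbed video written to:", output)    # None if an error occurred
```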
datasets/CoTMovieDubbing/README.md ADDED
@@ -0,0 +1,20 @@
1
+ ## MovieData Preprocessing
2
+
3
+ TODO
4
+
5
+ ## Data Tree
6
+
7
+ ```
8
+ moviecopy
9
+ |-- movie1_name
10
+ |-- xxxx.mp4
11
+ |-- xxxx.mp3
12
+ |-- xxxx.txt
13
+ |-- ...
14
+ |-- movie2_name
15
+ |-- xxxx.mp4
16
+ |-- xxxx.mp3
17
+ |-- xxxx.txt
18
+ |-- ...
19
+ |-- ...
20
+ ```
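
> Illustrative sketch (not part of the commit): the layout above pairs each clip's `.mp4`, `.mp3`, and `.txt` under a per-movie folder. A small walker for that tree is shown below; `moviecopy/` is the root shown above, and the completeness check is an assumption rather than part of the dataset tooling.

```python
# Walk the moviecopy/ tree sketched above and collect clips that have all three files.
from pathlib import Path

def collect_clips(root: str = "moviecopy"):
    clips = []
    for movie_dir in sorted(Path(root).iterdir()):
        if not movie_dir.is_dir():
            continue
        for mp4 in sorted(movie_dir.glob("*.mp4")):
            mp3, txt = mp4.with_suffix(".mp3"), mp4.with_suffix(".txt")
            if mp3.exists() and txt.exists():   # keep only complete (video, audio, text) triplets
                clips.append((mp4, mp3, txt))
    return clips

print(f"found {len(collect_clips())} complete clips")
```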
datasets/CoTMovieDubbing/filelist/cot_spk_for_speech_gen.lst ADDED
The diff for this file is too large to render. See raw diff
 
datasets/CoTMovieDubbing/filelist/mmlm_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
datasets/CoTMovieDubbing/filelist/mmlm_train.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
datasets/Grid/README.md ADDED
@@ -0,0 +1 @@
1
+ Refer to: [Grid](https://paperswithcode.com/dataset/grid)
datasets/V2C/README.md ADDED
@@ -0,0 +1 @@
1
+ Refer to: [V2C](https://github.com/chenqi008/V2C)
datasets/V2C/V2C_Setting2.txt ADDED
The diff for this file is too large to render. See raw diff
 
datasets/V2C/V2C_Setting3.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,237 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ addict==2.4.0
4
+ aiofiles==24.1.0
5
+ aiohappyeyeballs==2.4.6
6
+ aiohttp==3.11.13
7
+ aiosignal==1.3.2
8
+ altair==5.5.0
9
+ annotated-types==0.7.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.8.0
12
+ async-timeout==5.0.1
13
+ attrs==25.1.0
14
+ audioread==3.0.1
15
+ beautifulsoup4==4.13.3
16
+ bitsandbytes==0.42.0
17
+ blinker==1.9.0
18
+ boto3==1.37.12
19
+ botocore==1.37.12
20
+ cached_path==1.7.1
21
+ cachetools==5.5.2
22
+ certifi==2025.1.31
23
+ cffi==1.17.1
24
+ charset-normalizer==3.4.1
25
+ click==8.1.8
26
+ colorama==0.4.6
27
+ coloredlogs==15.0.1
28
+ contourpy==1.3.0
29
+ cycler==0.12.1
30
+ datasets==3.3.2
31
+ decorator==5.2.1
32
+ decord==0.6.0
33
+ deepspeed==0.15.4
34
+ dill==0.3.8
35
+ docstring_parser==0.16
36
+ einops==0.8.1
37
+ einops-exts==0.0.4
38
+ einx==0.3.0
39
+ eval_type_backport==0.2.2
40
+ exceptiongroup==1.2.2
41
+ fastapi==0.115.8
42
+ ffmpy==0.5.0
43
+ filelock==3.13.1
44
+ flash-attn==2.6.3
45
+ flatbuffers==25.2.10
46
+ fonttools==4.56.0
47
+ frozendict==2.4.6
48
+ frozenlist==1.5.0
49
+ fsspec==2024.6.1
50
+ future==1.0.0
51
+ gdown==5.2.0
52
+ gitdb==4.0.12
53
+ GitPython==3.1.44
54
+ google-api-core==2.24.2
55
+ google-auth==2.38.0
56
+ google-cloud-core==2.4.3
57
+ google-cloud-storage==2.19.0
58
+ google-crc32c==1.6.0
59
+ google-resumable-media==2.7.2
60
+ googleapis-common-protos==1.69.1
61
+ gradio==3.35.2
62
+ gradio_client==0.2.9
63
+ grpcio==1.70.0
64
+ h11==0.14.0
65
+ hjson==3.1.0
66
+ httpcore==0.17.3
67
+ httpx==0.24.0
68
+ huggingface-hub==0.27.1
69
+ humanfriendly==10.0
70
+ idna==3.10
71
+ imageio==2.37.0
72
+ imageio-ffmpeg==0.6.0
73
+ importlib_metadata==8.6.1
74
+ importlib_resources==6.5.2
75
+ jieba==0.42.1
76
+ Jinja2==3.1.4
77
+ jmespath==1.0.1
78
+ joblib==1.4.2
79
+ jsonschema==4.23.0
80
+ jsonschema-specifications==2024.10.1
81
+ kiwisolver==1.4.7
82
+ latex2mathml==3.77.0
83
+ lazy_loader==0.4
84
+ librosa==0.11.0
85
+ liger_kernel==0.4.2
86
+ linkify-it-py==2.0.3
87
+ llvmlite==0.43.0
88
+ loguru==0.7.3
89
+ Markdown==3.7
90
+ markdown-it-py==2.2.0
91
+ markdown2==2.5.3
92
+ MarkupSafe==2.1.5
93
+ matplotlib==3.9.4
94
+ mdit-py-plugins==0.3.3
95
+ mdurl==0.1.2
96
+ mmcls==0.25.0
97
+ mmcv-full==1.6.2
98
+ mmsegmentation==0.30.0
99
+ model-index==0.1.11
100
+ moviepy==2.1.2
101
+ mpmath==1.3.0
102
+ msgpack==1.1.0
103
+ multidict==6.1.0
104
+ multiprocess==0.70.16
105
+ narwhals==1.28.0
106
+ networkx==3.2.1
107
+ ninja==1.11.1.3
108
+ numba==0.60.0
109
+ numpy==1.26.3
110
+ nvidia-cublas-cu11==11.11.3.6
111
+ nvidia-cublas-cu12==12.1.3.1
112
+ nvidia-cuda-cupti-cu11==11.8.87
113
+ nvidia-cuda-cupti-cu12==12.1.105
114
+ nvidia-cuda-nvrtc-cu11==11.8.89
115
+ nvidia-cuda-nvrtc-cu12==12.1.105
116
+ nvidia-cuda-runtime-cu11==11.8.89
117
+ nvidia-cuda-runtime-cu12==12.1.105
118
+ nvidia-cudnn-cu11==9.1.0.70
119
+ nvidia-cudnn-cu12==9.1.0.70
120
+ nvidia-cufft-cu11==10.9.0.58
121
+ nvidia-cufft-cu12==11.0.2.54
122
+ nvidia-curand-cu11==10.3.0.86
123
+ nvidia-curand-cu12==10.3.2.106
124
+ nvidia-cusolver-cu11==11.4.1.48
125
+ nvidia-cusolver-cu12==11.4.5.107
126
+ nvidia-cusparse-cu11==11.7.5.86
127
+ nvidia-cusparse-cu12==12.1.0.106
128
+ nvidia-ml-py==12.570.86
129
+ nvidia-nccl-cu11==2.20.5
130
+ nvidia-nccl-cu12==2.20.5
131
+ nvidia-nvjitlink-cu12==12.8.61
132
+ nvidia-nvtx-cu11==11.8.86
133
+ nvidia-nvtx-cu12==12.1.105
134
+ omegaconf==2.3.0
135
+ onnxruntime==1.18.0
136
+ opencv-python==4.11.0.86
137
+ opendatalab==0.0.10
138
+ openmim==0.3.9
139
+ openxlab==0.0.11
140
+ ordered-set==4.1.0
141
+ orjson==3.10.15
142
+ packaging==24.2
143
+ pandas==2.2.3
144
+ peft==0.10.0
145
+ pillow==10.4.0
146
+ platformdirs==4.3.6
147
+ pooch==1.8.2
148
+ prettytable==3.14.0
149
+ proglog==0.1.10
150
+ propcache==0.3.0
151
+ proto-plus==1.26.1
152
+ protobuf==5.29.3
153
+ psutil==7.0.0
154
+ py-cpuinfo==9.0.0
155
+ pyarrow==19.0.1
156
+ pyasn1==0.6.1
157
+ pyasn1_modules==0.4.1
158
+ pycocoevalcap==1.2
159
+ pycocotools==2.0.8
160
+ pycparser==2.22
161
+ pycryptodome==3.21.0
162
+ pydantic==2.10.6
163
+ pydantic_core==2.27.2
164
+ pydeck==0.9.1
165
+ pydub==0.25.1
166
+ Pygments==2.19.1
167
+ pyparsing==3.2.1
168
+ pypinyin==0.53.0
169
+ PySocks==1.7.1
170
+ python-dateutil==2.9.0.post0
171
+ python-dotenv==1.0.1
172
+ python-multipart==0.0.20
173
+ pytz==2025.1
174
+ PyYAML==6.0.2
175
+ referencing==0.36.2
176
+ regex==2024.11.6
177
+ requests==2.32.3
178
+ rich==13.9.4
179
+ rpds-py==0.23.1
180
+ rsa==4.9
181
+ s3transfer==0.11.4
182
+ safetensors==0.5.3
183
+ scikit-learn==1.6.1
184
+ scipy==1.13.1
185
+ semantic-version==2.10.0
186
+ sentencepiece==0.1.99
187
+ shortuuid==1.0.13
188
+ shtab==1.7.1
189
+ six==1.17.0
190
+ smmap==5.0.2
191
+ sniffio==1.3.1
192
+ soundfile==0.13.1
193
+ soupsieve==2.6
194
+ soxr==0.5.0.post1
195
+ starlette==0.45.3
196
+ streamlit==1.42.2
197
+ streamlit-image-select==0.6.0
198
+ svgwrite==1.4.3
199
+ sympy==1.13.1
200
+ tabulate==0.9.0
201
+ tenacity==9.0.0
202
+ tensorboard==2.19.0
203
+ tensorboard-data-server==0.7.2
204
+ tensorboardX==2.6.2.2
205
+ termcolor==2.5.0
206
+ threadpoolctl==3.5.0
207
+ timm==0.9.12
208
+ tokenizers==0.19.1
209
+ toml==0.10.2
210
+ tomli==2.2.1
211
+ torch==2.4.1
212
+ torchaudio==2.4.1
213
+ torchdiffeq==0.2.5
214
+ torchvision==0.19.1
215
+ tornado==6.4.2
216
+ tqdm==4.67.1
217
+ transformers==4.42.1
218
+ triton==3.0.0
219
+ trl==0.10.1
220
+ typeguard==4.4.2
221
+ typing_extensions==4.12.2
222
+ tyro==0.9.16
223
+ tzdata==2025.1
224
+ uc-micro-py==1.0.3
225
+ urllib3==1.26.20
226
+ uvicorn==0.34.0
227
+ watchdog==6.0.0
228
+ wavedrom==2.0.3.post3
229
+ wcwidth==0.2.13
230
+ websockets==15.0
231
+ Werkzeug==3.1.3
232
+ x-transformers==2.1.37
233
+ xxhash==3.5.0
234
+ yacs==0.1.8
235
+ yapf==0.40.1
236
+ yarl==1.18.3
237
+ zipp==3.21.0
ruff.toml ADDED
@@ -0,0 +1,11 @@
1
+ line-length = 120
2
+ target-version = "py310"
3
+
4
+ [lint]
5
+ # Only ignore variables with names starting with "_".
6
+ dummy-variable-rgx = "^_.*$"
7
+ ignore = ["E402"]
8
+
9
+ [lint.isort]
10
+ force-single-line = false
11
+ lines-after-imports = 2
src/internvl/eval.py ADDED
@@ -0,0 +1,337 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torchvision.transforms as T
10
+ from decord import VideoReader, cpu
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from tqdm import tqdm
14
+ from transformers import AutoTokenizer
15
+
16
+
17
+ sys.path.insert(0, os.path.join(str(Path(__file__).resolve().parents[2]), "src/third_party/InternVL/internvl_chat"))
18
+ from internvl.model.internvl_chat.modeling_internvl_chat import InternVLChatModel # type: ignore
19
+
20
+
21
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
22
+ IMAGENET_STD = (0.229, 0.224, 0.225)
23
+
24
+
25
+ def build_transform(input_size):
26
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
27
+ transform = T.Compose(
28
+ [
29
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
30
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
31
+ T.ToTensor(),
32
+ T.Normalize(mean=MEAN, std=STD),
33
+ ]
34
+ )
35
+ return transform
36
+
37
+
38
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
39
+ best_ratio_diff = float("inf")
40
+ best_ratio = (1, 1)
41
+ area = width * height
42
+ for ratio in target_ratios:
43
+ target_aspect_ratio = ratio[0] / ratio[1]
44
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
45
+ if ratio_diff < best_ratio_diff:
46
+ best_ratio_diff = ratio_diff
47
+ best_ratio = ratio
48
+ elif ratio_diff == best_ratio_diff:
49
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
50
+ best_ratio = ratio
51
+ return best_ratio
52
+
53
+
54
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
55
+ orig_width, orig_height = image.size
56
+ aspect_ratio = orig_width / orig_height
57
+
58
+ # calculate the existing image aspect ratio
59
+ target_ratios = set(
60
+ (i, j)
61
+ for n in range(min_num, max_num + 1)
62
+ for i in range(1, n + 1)
63
+ for j in range(1, n + 1)
64
+ if i * j <= max_num and i * j >= min_num
65
+ )
66
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
67
+
68
+ # find the closest aspect ratio to the target
69
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
70
+
71
+ # calculate the target width and height
72
+ target_width = image_size * target_aspect_ratio[0]
73
+ target_height = image_size * target_aspect_ratio[1]
74
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
75
+
76
+ # resize the image
77
+ resized_img = image.resize((target_width, target_height))
78
+ processed_images = []
79
+ for i in range(blocks):
80
+ box = (
81
+ (i % (target_width // image_size)) * image_size,
82
+ (i // (target_width // image_size)) * image_size,
83
+ ((i % (target_width // image_size)) + 1) * image_size,
84
+ ((i // (target_width // image_size)) + 1) * image_size,
85
+ )
86
+ # split the image
87
+ split_img = resized_img.crop(box)
88
+ processed_images.append(split_img)
89
+ assert len(processed_images) == blocks
90
+ if use_thumbnail and len(processed_images) != 1:
91
+ thumbnail_img = image.resize((image_size, image_size))
92
+ processed_images.append(thumbnail_img)
93
+ return processed_images
94
+
95
+
96
+ def load_image(image_file, input_size=448, max_num=12):
97
+ image = Image.open(image_file).convert("RGB")
98
+ transform = build_transform(input_size=input_size)
99
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
100
+ pixel_values = [transform(image) for image in images]
101
+ pixel_values = torch.stack(pixel_values)
102
+ return pixel_values
103
+
104
+
105
+ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
106
+ if bound:
107
+ start, end = bound[0], bound[1]
108
+ else:
109
+ start, end = -100000, 100000
110
+ start_idx = max(first_idx, round(start * fps))
111
+ end_idx = min(round(end * fps), max_frame)
112
+ seg_size = float(end_idx - start_idx) / num_segments
113
+ frame_indices = np.array(
114
+ [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)]
115
+ )
116
+ return frame_indices
117
+
118
+
119
+ def load_video(
120
+ video_path,
121
+ bound=None,
122
+ input_size=448,
123
+ max_num=1,
124
+ num_segments=32,
125
+ cache_dir=".cache/expcache",
126
+ ):
127
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
128
+ max_frame = len(vr) - 1
129
+ fps = float(vr.get_avg_fps())
130
+
131
+ video_cache_dir = video_path.split("/")[-2] + "_" + os.path.basename(video_path).split(".")[0]
132
+ video_cache_dir = os.path.join(cache_dir, video_cache_dir)
133
+ cache_filename = os.path.join(
134
+ video_cache_dir,
135
+ f"_bound-{bound}_input_size-{input_size}_max_num-{max_num}_num_segments-{num_segments}.pt",
136
+ )
137
+ if os.path.exists(cache_filename) and os.path.isfile(cache_filename):
138
+ cache = torch.load(cache_filename, weights_only=True)
139
+ pixel_values = cache["pixel_values"]
140
+ num_patches_list = cache["num_patches_list"]
141
+
142
+ else:
143
+ pixel_values_list, num_patches_list = [], []
144
+ transform = build_transform(input_size=input_size)
145
+ frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
146
+
147
+ frame_indices = np.append(0, frame_indices) # Add 0 at the beginning of the list
148
+ frame_indices = np.append(frame_indices, max_frame) # Add max_frame at the end of the list
149
+
150
+ os.makedirs(video_cache_dir, exist_ok=True)
151
+
152
+ idx = 0
153
+ for frame_index in frame_indices:
154
+ img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
155
+
156
+ img.save(os.path.join(video_cache_dir, f"frame_{frame_index}_tile_{idx}.png"))
157
+
158
+ img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
159
+
160
+ pixel_values = [transform(tile) for tile in img]
161
+ pixel_values = torch.stack(pixel_values)
162
+ num_patches_list.append(pixel_values.shape[0])
163
+ pixel_values_list.append(pixel_values)
164
+
165
+ idx += 1
166
+ pixel_values = torch.cat(pixel_values_list)
167
+
168
+ os.makedirs(cache_dir, exist_ok=True)
169
+ torch.save({"pixel_values": pixel_values, "num_patches_list": num_patches_list}, cache_filename)
170
+
171
+ return pixel_values, num_patches_list
172
+
173
+
174
+ def analyze_predictions(file_path):
175
+ # Read the CSV file
176
+ df = pd.read_csv(file_path)
177
+
178
+ # Calculate overall accuracy
179
+ total_samples = len(df)
180
+ correct_predictions = df["is_correct"].value_counts().get(True, 0)
181
+ overall_accuracy = correct_predictions / total_samples
182
+
183
+ # Initialize metrics for each class
184
+ classes = ["A", "B", "C"]
185
+ class_metrics = {}
186
+
187
+ for cls in classes:
188
+ # Filter for samples where target is this class
189
+ true_class = df[df["target"] == cls]
190
+ # Filter for samples where prediction is this class
191
+ # pred_class = df[df["predict"] == cls]
192
+
193
+ # Calculate TP, FP, FN
194
+ TP = len(df[(df["target"] == cls) & (df["predict"] == cls)])
195
+ FP = len(df[(df["target"] != cls) & (df["predict"] == cls)])
196
+ FN = len(df[(df["target"] == cls) & (df["predict"] != cls)])
197
+
198
+ # Calculate precision, recall, F1
199
+ precision = TP / (TP + FP) if (TP + FP) > 0 else 0
200
+ recall = TP / (TP + FN) if (TP + FN) > 0 else 0
201
+ f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
202
+
203
+ # Store metrics
204
+ class_metrics[cls] = {
205
+ "total_samples": len(true_class),
206
+ "precision": precision,
207
+ "recall": recall,
208
+ "f1": f1,
209
+ "true_positives": TP,
210
+ "false_positives": FP,
211
+ "false_negatives": FN,
212
+ }
213
+
214
+ print(f"Overall Accuracy: {overall_accuracy:.4f} ({correct_predictions}/{total_samples})")
215
+ print()
216
+ print("Indicators for each category:")
217
+
218
+ for cls in classes:
219
+ metrics = class_metrics[cls]
220
+ print(f" Class {cls}:")
221
+ print(f" Total Samples: {metrics['total_samples']}")
222
+ print(f" Precision: {metrics['precision']:.4f}")
223
+ print(f" Recall: {metrics['recall']:.4f}")
224
+ print(f" F1 Score: {metrics['f1']:.4f}")
225
+ print(f" True Positives: {metrics['true_positives']}")
226
+ print(f" False Positives: {metrics['false_positives']}")
227
+ print(f" False Negatives: {metrics['false_negatives']}")
228
+
229
+ return overall_accuracy, class_metrics
230
+
231
+
232
+ def s_thread(video_dir, model_path, device, chunk, idx, queue):
233
+ model = InternVLChatModel.from_pretrained(
234
+ model_path,
235
+ torch_dtype=torch.bfloat16,
236
+ low_cpu_mem_usage=True,
237
+ use_flash_attn=True,
238
+ )
239
+ model = model.eval().to(device)
240
+
241
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
242
+
243
+ generation_config = dict(max_new_tokens=1024, do_sample=False)
244
+
245
+ res = []
246
+ for line in tqdm(chunk, position=idx, desc=f"Device {device}"):
247
+ data = json.loads(line)
248
+
249
+ video_path = os.path.join(video_dir, data["video"])
250
+ ques = data["conversations"][0]["value"]
251
+
252
+ target_ans = data["conversations"][1]["value"].split("<CONCLUSION>")[1].split("</CONCLUSION>")[0].strip()
253
+
254
+ pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
255
+ pixel_values = pixel_values.to(torch.bfloat16).to(device)
256
+ video_prefix = "".join([f"Frame{i + 1}: <image>\n" for i in range(len(num_patches_list))])
257
+ question = video_prefix + f"{ques}"
258
+ response = model.chat(
259
+ tokenizer,
260
+ pixel_values,
261
+ question,
262
+ generation_config,
263
+ num_patches_list=num_patches_list,
264
+ history=None,
265
+ return_history=False,
266
+ )
267
+
268
+ try:
269
+ ans = response.split("<CONCLUSION>")[1].split("</CONCLUSION>")[0].strip()
270
+ except Exception as e:
271
+ print(f"Error: {e}, response: {response}")
272
+ ans = response.strip()[0]
273
+
274
+ is_correct = False
275
+ if ans == target_ans:
276
+ is_correct = True
277
+
278
+ res.append(f"{video_path},{is_correct},{target_ans},{ans}")
279
+
280
+ queue.put(res)
281
+
282
+
283
+ if __name__ == "__main__":
284
+ import argparse
285
+
286
+ import torch.multiprocessing as mp
287
+
288
+ parser = argparse.ArgumentParser(description="eval script for mmlm")
289
+ parser.add_argument("--model_path", type=str, help="Path to the model checkpoint.")
290
+ parser.add_argument("--test_file", type=str, help="Path to the test file.")
291
+ parser.add_argument("--video_dir", type=str, help="Path to the test video directory.")
292
+ parser.add_argument("--gpuids", type=str, help="GPU ids to use.")
293
+
294
+ # python eval.py --model_path /path/to/model --test_file /path/to/test_file --video_dir /path/to/video_dir --gpuids 0,1,2,3
295
+
296
+ args = parser.parse_args()
297
+
298
+ model_path = args.model_path
299
+ test_file = args.test_file
300
+ video_dir = args.video_dir
301
+
302
+ gpu_ids = args.gpuids.split(",") if args.gpuids else ["0"]
303
+
304
+ cot_test = Path(test_file).read_text().splitlines()
305
+
306
+ chunks = np.array_split(cot_test, len(gpu_ids))
307
+
308
+ mp.set_start_method("spawn", force=True)
309
+
310
+ queue = mp.Queue()
311
+
312
+ processes = []
313
+ for idx, chunk in enumerate(chunks):
314
+ device = gpu_ids[idx % len(gpu_ids)]
315
+ device = f"cuda:{device}"
316
+
317
+ p = mp.Process(target=s_thread, args=(video_dir, model_path, device, chunk, idx, queue))
318
+ processes.append(p)
319
+ p.start()
320
+
321
+ for process in processes:
322
+ process.join()
323
+
324
+ result = []
325
+ for _ in range(len(chunks)):
326
+ res = queue.get()
327
+ result.extend(res)
328
+
329
+ res_saved = f"{'__'.join(model_path.split('/'))}_res.csv"
330
+ with open(res_saved, "w") as f:
331
+ f.write("video_id,is_correct,target,predict\n")
332
+ for res in result:
333
+ f.write(f"{res}\n")
334
+
335
+ accuracy, metrics = analyze_predictions(res_saved)
336
+
337
+ print("All processes finished.\n\n")
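
> Illustrative sketch (not part of the commit): the script writes one CSV row per test clip under the header `video_id,is_correct,target,predict`, and `analyze_predictions` can be rerun on a saved CSV without reloading the model. The CSV path below is hypothetical, and importing the module requires the `src/third_party/InternVL` checkout it references at import time.

```python
# Re-score a previously generated results CSV (header: video_id,is_correct,target,predict)
# without touching the GPUs. Requires src/third_party/InternVL to be present, since
# src/internvl/eval.py imports InternVLChatModel when the module loads.
from src.internvl.eval import analyze_predictions

accuracy, class_metrics = analyze_predictions("path__to__model_res.csv")  # hypothetical path
print(f"accuracy={accuracy:.4f}",
      {cls: round(m["f1"], 4) for cls, m in class_metrics.items()})
```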
src/moviedubber/configs/basemodel.yaml ADDED
@@ -0,0 +1,9 @@
1
+ model:
2
+ arch:
3
+ dim: 1024
4
+ depth: 22
5
+ heads: 16
6
+ ff_mult: 2
7
+ text_dim: 512
8
+ conv_layers: 4
9
+
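
> Illustrative sketch (not part of the commit): the `model.arch` values above are the transformer hyperparameters that `load_model` in `src/moviedubber/infer/utils_infer.py` expects as `model_cfg`. The example uses OmegaConf (already pinned in requirements.txt); the `DiT` class name and the checkpoint/vocab paths are assumptions for illustration.

```python
# Sketch: read basemodel.yaml and hand the arch dict to load_model from utils_infer.py.
from omegaconf import OmegaConf

from src.moviedubber.infer.utils_infer import load_model
from src.moviedubber.model.dit import DiT  # assumed class name; see src/moviedubber/model/dit.py

cfg = OmegaConf.load("src/moviedubber/configs/basemodel.yaml")
model_cfg = OmegaConf.to_container(cfg.model.arch)  # dim, depth, heads, ff_mult, text_dim, conv_layers

ema_model = load_model(
    model_cls=DiT,
    model_cfg=model_cfg,
    ckpt_path="/path/to/ckpt_file.pth",    # placeholder, as in infer/basic.toml
    vocab_file="/path/to/vocab_file.txt",  # placeholder, as in infer/basic.toml
)
```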
src/moviedubber/eval.py ADDED
@@ -0,0 +1,245 @@
1
+ import argparse
2
+ import os
3
+ import string
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from pathlib import Path
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ from evaluate import load
11
+ from pymcd.mcd import Calculate_MCD
12
+ from tqdm import tqdm
13
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Wav2Vec2FeatureExtractor, WavLMForXVector, pipeline
14
+
15
+
16
+ def convert_numbers_to_words(text):
17
+ """Convert single digits in text to words with spaces"""
18
+ number_word_map = {
19
+ "0": "zero",
20
+ "1": "one",
21
+ "2": "two",
22
+ "3": "three",
23
+ "4": "four",
24
+ "5": "five",
25
+ "6": "six",
26
+ "7": "seven",
27
+ "8": "eight",
28
+ "9": "nine",
29
+ }
30
+
31
+ words = text.split()
32
+ converted_words = []
33
+
34
+ for word in words:
35
+ # Check if the word contains both letters and numbers (like 'j4')
36
+ if any(c.isdigit() for c in word) and any(c.isalpha() for c in word):
37
+ # Split the word into parts and convert digits
38
+ new_word = ""
39
+ for c in word:
40
+ if c.isdigit():
41
+ new_word += " " + number_word_map[c]
42
+ else:
43
+ new_word += c
44
+ converted_words.append(new_word)
45
+ # Check if the word is a single digit
46
+ elif word.isdigit() and len(word) == 1:
47
+ converted_words.append(number_word_map[word])
48
+ else:
49
+ converted_words.append(word)
50
+
51
+ return " ".join(converted_words)
52
+
53
+
54
+ def clean_text(text):
55
+ text = convert_numbers_to_words(text)
56
+ text = text.translate(str.maketrans("", "", string.punctuation))
57
+ text = text.lower()
58
+ return text
59
+
60
+
61
+ def wer_pipe(gen_dir: str, target_dir: str, model_id="openai/whisper-large-v3-turbo"):
62
+ device = "cuda" if torch.cuda.is_available() else "cpu"
63
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
64
+ print(f"Using Model: {model_id} for WER Evaluation")
65
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
66
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
67
+ ).to(device)
68
+
69
+ processor = AutoProcessor.from_pretrained(model_id)
70
+
71
+ pipe = pipeline(
72
+ "automatic-speech-recognition",
73
+ model=model,
74
+ tokenizer=processor.tokenizer,
75
+ feature_extractor=processor.feature_extractor,
76
+ torch_dtype=torch_dtype,
77
+ device=device,
78
+ )
79
+
80
+ gen_list = list(Path(gen_dir).glob("*.wav"))
81
+ for line in tqdm(gen_list, desc="Processing audio files"):
82
+ wav = line
83
+ if not wav.exists():
84
+ continue
85
+
86
+ text = pipe(librosa.load(wav, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
87
+ with open(wav.with_suffix(".asrtxt"), "w") as fw:
88
+ fw.write(text)
89
+
90
+ wer_metric = load("wer")
91
+
92
+ val_list = list(Path(target_dir).glob("*.txt"))
93
+
94
+ wer = []
95
+ for txt in tqdm(val_list, desc="Calculating WER"):
96
+ try:
97
+ # Since the original text is automatically transcribed and has not been manually verified, all texts will be cleaned here.
98
+
99
+ target_text = " ".join(set(txt.read_text().splitlines()))
100
+ target_text = clean_text(target_text)
101
+
102
+ gen_text = " ".join(Path(os.path.join(gen_dir, txt.with_suffix(".asrtxt").name)).read_text().splitlines())
103
+ gen_text = clean_text(gen_text)
104
+
105
+ if target_text == "" or gen_text == "":
106
+ continue
107
+
108
+ wer_ = wer_metric.compute(references=[target_text], predictions=[gen_text])
109
+
110
+ except Exception as e:
111
+ print("Error in wer calculation: ", e)
112
+ continue
113
+
114
+ wer.append(wer_)
115
+
116
+ return np.mean(wer)
117
+
118
+
119
+ def spk_sim_pipe(gen_dir, target_dir):
120
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
121
+ model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv").cuda()
122
+
123
+ cosine_sim = torch.nn.CosineSimilarity(dim=-1)
124
+
125
+ val_list = list(Path(target_dir).glob("*.wav"))
126
+
127
+ scos = []
128
+
129
+ for target_wav in tqdm(val_list, desc="Calculating speaker similarity"):
130
+ target = librosa.load(target_wav, sr=16000)[0]
131
+ gen = librosa.load(os.path.join(gen_dir, target_wav.name), sr=16000)[0]
132
+
133
+ try:
134
+ input1 = feature_extractor(gen, return_tensors="pt", sampling_rate=16000).to("cuda")
135
+ embeddings1 = model(**input1).embeddings
136
+
137
+ input2 = feature_extractor(target, return_tensors="pt", sampling_rate=16000).to("cuda")
138
+ embeddings2 = model(**input2).embeddings
139
+
140
+ similarity = cosine_sim(embeddings1[0], embeddings2[0])
141
+
142
+ except Exception as e:
143
+ print(f"Error in {target_wav}, {e}")
144
+ continue
145
+
146
+ scos.append(similarity.detach().cpu().numpy())
147
+
148
+ return np.mean(scos)
149
+
150
+
151
+ def calculate_mcd_for_wav(target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl):
152
+ _mcd_dtw = mcd_toolbox_dtw.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
153
+ _mcd_dtw_sl = mcd_toolbox_dtw_sl.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
154
+ return _mcd_dtw, _mcd_dtw_sl
155
+
156
+
157
+ def mcd_pipe(gen_dir, target_dir, num_processes=16):
158
+ mcd_toolbox_dtw = Calculate_MCD(MCD_mode="dtw")
159
+ mcd_toolbox_dtw_sl = Calculate_MCD(MCD_mode="dtw_sl")
160
+
161
+ val_list = list(Path(target_dir).glob("*.wav"))
162
+
163
+ mcd_dtw = []
164
+ mcd_dtw_sl = []
165
+
166
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
167
+ futures = [
168
+ executor.submit(calculate_mcd_for_wav, target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl)
169
+ for target_wav in val_list
170
+ ]
171
+ for future in tqdm(futures, desc="Calculating MCD"):
172
+ _mcd_dtw, _mcd_dtw_sl = future.result()
173
+ mcd_dtw.append(_mcd_dtw)
174
+ mcd_dtw_sl.append(_mcd_dtw_sl)
175
+
176
+ return np.mean(mcd_dtw), np.mean(mcd_dtw_sl)
177
+
178
+
179
+ def run_all_metrics(gen_dir, target_dir, whisper_model="openai/whisper-large-v3-turbo"):
180
+ """Run all evaluation metrics and return results"""
181
+ results = {}
182
+
183
+ print("Running WER evaluation...")
184
+ results["wer"] = wer_pipe(gen_dir, target_dir, model_id=whisper_model)
185
+
186
+ print("Running speaker similarity evaluation...")
187
+ results["speaker_similarity"] = spk_sim_pipe(gen_dir, target_dir)
188
+
189
+ print("Running MCD evaluation...")
190
+ mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
191
+ results["mcd_dtw"] = mcd_dtw
192
+ results["mcd_dtw_sl"] = mcd_dtw_sl
193
+
194
+ return results
195
+
196
+
197
+ if __name__ == "__main__":
198
+ parser = argparse.ArgumentParser(description="Audio evaluation metrics")
199
+ parser.add_argument("--gen_dir", type=str, required=True, help="Directory containing generated audio files")
200
+ parser.add_argument("--target_dir", type=str, required=True, help="Directory containing target audio files")
201
+ parser.add_argument(
202
+ "--metric",
203
+ type=str,
204
+ default="all",
205
+ choices=["wer", "spk_sim", "mcd", "all"],
206
+ help="Evaluation metric to use",
207
+ )
208
+ parser.add_argument(
209
+ "--whisper_model",
210
+ type=str,
211
+ default="openai/whisper-large-v3-turbo",
212
+ help="Whisper model to use for WER evaluation",
213
+ )
214
+ # python eval.py --gen_dir path/to/generated --target_dir path/to/target
215
+ # keep the name of gen_wav and target_wav the same
216
+ args = parser.parse_args()
217
+
218
+ gen_dir = args.gen_dir
219
+ target_dir = args.target_dir
220
+
221
+ if not os.path.exists(gen_dir):
222
+ raise ValueError(f"Generated audio directory does not exist: {gen_dir}")
223
+ if not os.path.exists(target_dir):
224
+ raise ValueError(f"Target audio directory does not exist: {target_dir}")
225
+
226
+ if args.metric == "all":
227
+ results = run_all_metrics(gen_dir, target_dir, args.whisper_model)
228
+ print("\nEvaluation Results:")
229
+ print(f"WER: {results['wer']:.4f}")
230
+ print(f"Speaker Similarity: {results['speaker_similarity']:.4f}")
231
+ print(f"MCD (DTW): {results['mcd_dtw']:.4f}")
232
+ print(f"MCD (DTW-SL): {results['mcd_dtw_sl']:.4f}")
233
+
234
+ elif args.metric == "wer":
235
+ wer = wer_pipe(gen_dir, target_dir, model_id=args.whisper_model)
236
+ print(f"WER: {wer:.4f}")
237
+
238
+ elif args.metric == "spk_sim":
239
+ spk_sim = spk_sim_pipe(gen_dir, target_dir)
240
+ print(f"Speaker Similarity: {spk_sim:.4f}")
241
+
242
+ elif args.metric == "mcd":
243
+ mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
244
+ print(f"MCD (DTW): {mcd_dtw:.4f}")
245
+ print(f"MCD (DTW-SL): {mcd_dtw_sl:.4f}")
src/moviedubber/infer/basic.toml ADDED
@@ -0,0 +1,4 @@
1
+ ckpt_file = "/path/to/ckpt_file.pth"
2
+ vocab_file = "/path/to/vocab_file.txt"
3
+ vocoder_local_path = "/path/to/bigvgan"
4
+ campplus_path = "/path/to/campplus.onnx"
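
> Illustrative sketch (not part of the commit): `app.py` reads this file with `tomli` and passes the parsed dict to `load_models` (imported from `src/moviedubber/infer_with_mmlm_result.py`). Only the four keys shown above are assumed to exist.

```python
# Sketch: load basic.toml the same way app.py does and hand it to load_models.
import tomli

from src.moviedubber.infer_with_mmlm_result import load_models

with open("src/moviedubber/infer/basic.toml", "rb") as f:
    config = tomli.load(f)  # ckpt_file, vocab_file, vocoder_local_path, campplus_path

ema_model, vocoder, ort_session = load_models(config, device="cpu")
```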
src/moviedubber/infer/utils_infer.py ADDED
@@ -0,0 +1,399 @@
1
+ # A unified script for inference process
2
+ # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
+
4
+ import re
5
+ from importlib.resources import files
6
+
7
+ import matplotlib
8
+
9
+
10
+ matplotlib.use("Agg")
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torchaudio
16
+ import tqdm
17
+
18
+ from src.moviedubber.model import CFM
19
+ from src.moviedubber.model.utils import convert_char_to_pinyin, get_tokenizer
20
+
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
23
+
24
+ # -----------------------------------------
25
+
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ win_length = 1024
30
+ n_fft = 1024
31
+ mel_spec_type = "bigvgan"
32
+ target_rms = 0.1
33
+ cross_fade_duration = 0.15
34
+ ode_method = "euler"
35
+ nfe_step = 32 # 16, 32
36
+ # cfg_strength = 2.0
37
+ cfg_strength = 1
38
+ sway_sampling_coef = -1.0
39
+ speed = 1.0
40
+ fix_duration = None
41
+
42
+ # -----------------------------------------
43
+
44
+
45
+ # chunk text into smaller pieces
46
+
47
+
48
+ def chunk_text(text, max_chars=135):
49
+ """
50
+ Splits the input text into chunks, each with a maximum number of characters.
51
+
52
+ Args:
53
+ text (str): The text to be split.
54
+ max_chars (int): The maximum number of characters per chunk.
55
+
56
+ Returns:
57
+ List[str]: A list of text chunks.
58
+ """
59
+ chunks = []
60
+ current_chunk = ""
61
+ # Split the text into sentences based on punctuation followed by whitespace
62
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
63
+
64
+ for sentence in sentences:
65
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
66
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
67
+ else:
68
+ if current_chunk:
69
+ chunks.append(current_chunk.strip())
70
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
71
+
72
+ if current_chunk:
73
+ chunks.append(current_chunk.strip())
74
+
75
+ return chunks
76
+
77
+
78
+ # load vocoder
79
+ def load_vocoder(local_path, device=device):
80
+ from src.third_party.BigVGAN import bigvgan
81
+
82
+ vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
83
+
84
+ vocoder.remove_weight_norm()
85
+ vocoder = vocoder.eval().to(device)
86
+ return vocoder
87
+
88
+
89
+ # load model checkpoint for inference
90
+
91
+
92
+ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
93
+ if dtype is None:
94
+ dtype = (
95
+ torch.float16
96
+ if "cuda" in device
97
+ and torch.cuda.get_device_properties(device).major >= 6
98
+ and not torch.cuda.get_device_name().endswith("[ZLUDA]")
99
+ else torch.float32
100
+ )
101
+ model = model.to(dtype)
102
+
103
+ ckpt_type = ckpt_path.split(".")[-1]
104
+ if ckpt_type == "safetensors":
105
+ from safetensors.torch import load_file
106
+
107
+ checkpoint = load_file(ckpt_path, device=device)
108
+ else:
109
+ checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
110
+
111
+ if use_ema:
112
+ if ckpt_type == "safetensors":
113
+ checkpoint = {"ema_model_state_dict": checkpoint}
114
+ checkpoint["model_state_dict"] = {
115
+ k.replace("ema_model.", ""): v
116
+ for k, v in checkpoint["ema_model_state_dict"].items()
117
+ if k not in ["initted", "step"]
118
+ }
119
+
120
+ # patch for backward compatibility, 305e3ea
121
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
122
+ if key in checkpoint["model_state_dict"]:
123
+ del checkpoint["model_state_dict"][key]
124
+
125
+ state_dict_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
126
+ if state_dict_result.unexpected_keys:
127
+ print("\nUnexpected keys in state_dict:", state_dict_result.unexpected_keys)
128
+ if state_dict_result.missing_keys:
129
+ print("\nMissing keys in state_dict:", state_dict_result.missing_keys)
130
+ else:
131
+ if ckpt_type == "safetensors":
132
+ checkpoint = {"model_state_dict": checkpoint}
133
+ model.load_state_dict(checkpoint["model_state_dict"], strict=True)
134
+
135
+ del checkpoint
136
+ torch.cuda.empty_cache()
137
+
138
+ return model.to(device)
139
+
140
+
141
+ # load model for inference
142
+
143
+
144
+ def load_model(
145
+ model_cls,
146
+ model_cfg,
147
+ ckpt_path,
148
+ controlnet=None,
149
+ mel_spec_type=mel_spec_type,
150
+ vocab_file="",
151
+ ode_method=ode_method,
152
+ use_ema=True,
153
+ device=device,
154
+ ):
155
+ tokenizer = "custom"
156
+
157
+ print("\nvocab : ", vocab_file)
158
+ print("token : ", tokenizer)
159
+ print("model : ", ckpt_path, "\n")
160
+
161
+ vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
162
+
163
+ if controlnet is not None:
164
+ controlnet = controlnet(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
165
+
166
+ model = CFM(
167
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
168
+ mel_spec_kwargs=dict(
169
+ n_fft=n_fft,
170
+ hop_length=hop_length,
171
+ win_length=win_length,
172
+ n_mel_channels=n_mel_channels,
173
+ target_sample_rate=target_sample_rate,
174
+ mel_spec_type=mel_spec_type,
175
+ ),
176
+ odeint_kwargs=dict(
177
+ method=ode_method,
178
+ ),
179
+ vocab_char_map=vocab_char_map,
180
+ controlnet=controlnet,
181
+ ).to(device)
182
+
183
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
184
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
185
+
186
+ return model
187
+
188
+
189
+ def infer_process(
190
+ ref_audio,
191
+ ref_text,
192
+ ref_clip,
193
+ ref_lip,
194
+ gen_text,
195
+ gen_clip,
196
+ gen_lip,
197
+ model_obj,
198
+ vocoder,
199
+ gen_caption=None,
200
+ mel_spec_type=mel_spec_type,
201
+ progress=tqdm,
202
+ target_rms=target_rms,
203
+ cross_fade_duration=cross_fade_duration,
204
+ nfe_step=nfe_step,
205
+ cfg_strength=cfg_strength,
206
+ sway_sampling_coef=sway_sampling_coef,
207
+ speed=speed,
208
+ fix_duration=fix_duration,
209
+ device=device,
210
+ ):
211
+ # Split the input text into batches
212
+ audio, sr = torchaudio.load(ref_audio)
213
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
214
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
215
+
216
+ return infer_batch_process(
217
+ (audio, sr),
218
+ ref_text,
219
+ ref_clip,
220
+ ref_lip,
221
+ gen_text_batches,
222
+ gen_clip,
223
+ gen_lip,
224
+ model_obj,
225
+ vocoder,
226
+ gen_caption=gen_caption,
227
+ mel_spec_type=mel_spec_type,
228
+ progress=progress,
229
+ target_rms=target_rms,
230
+ cross_fade_duration=cross_fade_duration,
231
+ nfe_step=nfe_step,
232
+ cfg_strength=cfg_strength,
233
+ sway_sampling_coef=sway_sampling_coef,
234
+ speed=speed,
235
+ fix_duration=fix_duration,
236
+ device=device,
237
+ )
238
+
239
+
240
+ # infer batches
241
+
242
+
243
+ def infer_batch_process(
244
+ ref_audio,
245
+ ref_text,
246
+ ref_clip,
247
+ ref_lip,
248
+ gen_text_batches,
249
+ gen_clip,
250
+ gen_lip,
251
+ model_obj,
252
+ vocoder,
253
+ gen_caption=None,
254
+ mel_spec_type="vocos",
255
+ target_rms=0.1,
256
+ cross_fade_duration=0.15,
257
+ nfe_step=32,
258
+ cfg_strength=2.0,
259
+ sway_sampling_coef=-1,
260
+ speed=1,
261
+ fix_duration=None,
262
+ device=None,
263
+ ):
264
+ audio, sr = ref_audio
265
+ if audio.shape[0] > 1:
266
+ audio = torch.mean(audio, dim=0, keepdim=True)
267
+
268
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
269
+ if rms < target_rms:
270
+ audio = audio * target_rms / rms
271
+ if sr != target_sample_rate:
272
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
273
+ audio = resampler(audio)
274
+ audio = audio.to(device)
275
+
276
+ generated_waves = []
277
+ spectrograms = []
278
+
279
+ if len(ref_text[-1].encode("utf-8")) == 1:
280
+ ref_text = ref_text + " "
281
+
282
+ for i, gen_text in enumerate(gen_text_batches):
283
+ # Prepare the text
284
+ text_list = [ref_text + gen_text]
285
+ final_text_list = convert_char_to_pinyin(text_list)
286
+
287
+ ref_audio_len = audio.shape[-1] // hop_length
288
+ if ref_clip is not None:
289
+ ref_clip = F.interpolate(
290
+ ref_clip.unsqueeze(0).transpose(1, 2), size=ref_audio_len, mode="linear", align_corners=False
291
+ )
292
+ else:
293
+ ref_clip = torch.zeros(1, 768, ref_audio_len).to(device)
294
+
295
+ if fix_duration is not None:
296
+ duration = int(fix_duration * target_sample_rate / hop_length)
297
+ gen_audio_len = duration - ref_audio_len
298
+
299
+ gen_clip = F.interpolate(
300
+ gen_clip.unsqueeze(0).transpose(1, 2), size=gen_audio_len, mode="linear", align_corners=False
301
+ )
302
+
303
+ else:
304
+ # Calculate duration
305
+ ref_text_len = len(ref_text.encode("utf-8"))
306
+ gen_text_len = len(gen_text.encode("utf-8"))
307
+
308
+ gen_audio_len = int(ref_audio_len / ref_text_len * gen_text_len)
309
+ gen_clip = F.interpolate(
310
+ gen_clip.unsqueeze(0).transpose(1, 2), size=gen_audio_len, mode="linear", align_corners=False
311
+ )
312
+
313
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
314
+
315
+ if ref_lip is None:
316
+ ref_lip = torch.zeros(ref_audio_len // 4, 512)
317
+
318
+ clip = torch.cat([ref_clip, gen_clip], dim=-1).permute(0, 2, 1).to(device)
319
+
320
+ if gen_lip is not None:
321
+ lip = torch.cat([ref_lip.unsqueeze(0).transpose(1, 2), gen_lip.unsqueeze(0).transpose(1, 2)], dim=-1).to(
322
+ device
323
+ )
324
+ lip = F.pad(lip, (0, duration - lip.size(-1)), value=0).permute(0, 2, 1)
325
+ else:
326
+ lip = None
327
+
328
+ # inference
329
+ with torch.inference_mode():
330
+ generated, _ = model_obj.sample(
331
+ cond=audio,
332
+ text=final_text_list,
333
+ clip=clip,
334
+ lip=lip,
335
+ caption_emb=gen_caption,
336
+ duration=duration,
337
+ steps=nfe_step,
338
+ cfg_strength=cfg_strength,
339
+ sway_sampling_coef=sway_sampling_coef,
340
+ no_ref_audio=False,
341
+ )
342
+
343
+ generated = generated.to(torch.float32)
344
+ generated = generated[:, ref_audio_len:, :]
345
+ generated_mel_spec = generated.permute(0, 2, 1)
346
+ if mel_spec_type == "vocos":
347
+ generated_wave = vocoder.decode(generated_mel_spec)
348
+ elif mel_spec_type == "bigvgan":
349
+ generated_wave = vocoder(generated_mel_spec)
350
+ if rms < target_rms:
351
+ generated_wave = generated_wave * rms / target_rms
352
+
353
+ # wav -> numpy
354
+ generated_wave = generated_wave.squeeze().cpu().numpy()
355
+
356
+ generated_waves.append(generated_wave)
357
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
358
+
359
+ # Combine all generated waves with cross-fading
360
+ if cross_fade_duration <= 0:
361
+ # Simply concatenate
362
+ final_wave = np.concatenate(generated_waves)
363
+ else:
364
+ final_wave = generated_waves[0]
365
+ for i in range(1, len(generated_waves)):
366
+ prev_wave = final_wave
367
+ next_wave = generated_waves[i]
368
+
369
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
370
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
371
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
372
+
373
+ if cross_fade_samples <= 0:
374
+ # No overlap possible, concatenate
375
+ final_wave = np.concatenate([prev_wave, next_wave])
376
+ continue
377
+
378
+ # Overlapping parts
379
+ prev_overlap = prev_wave[-cross_fade_samples:]
380
+ next_overlap = next_wave[:cross_fade_samples]
381
+
382
+ # Fade out and fade in
383
+ fade_out = np.linspace(1, 0, cross_fade_samples)
384
+ fade_in = np.linspace(0, 1, cross_fade_samples)
385
+
386
+ # Cross-faded overlap
387
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
388
+
389
+ # Combine
390
+ new_wave = np.concatenate(
391
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
392
+ )
393
+
394
+ final_wave = new_wave
395
+
396
+ # Create a combined spectrogram
397
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
398
+
399
+ return final_wave, target_sample_rate, combined_spectrogram
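The cross-fade above is a plain linear overlap-add. Below is a small self-contained sketch of the same idea on synthetic data; the `crossfade` helper and the sine-wave inputs are illustrative only and are not part of the repository.

```python
import numpy as np


def crossfade(prev_wave: np.ndarray, next_wave: np.ndarray, sr: int, fade_s: float = 0.15) -> np.ndarray:
    # Linear cross-fade over at most fade_s seconds, capped by both segment lengths.
    n = min(int(fade_s * sr), len(prev_wave), len(next_wave))
    if n <= 0:
        return np.concatenate([prev_wave, next_wave])
    fade_out = np.linspace(1.0, 0.0, n)
    fade_in = np.linspace(0.0, 1.0, n)
    overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
    return np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])


# Toy usage: two one-second tones at 24 kHz, merged with a 0.15 s cross-fade.
sr = 24000
t = np.linspace(0, 1, sr, endpoint=False)
a, b = np.sin(2 * np.pi * 220 * t), np.sin(2 * np.pi * 330 * t)
print(crossfade(a, b, sr).shape)  # (len(a) + len(b) - 3600,)
```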
src/moviedubber/infer/video_preprocess.py ADDED
@@ -0,0 +1,315 @@
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import os
5
+ import os.path as osp
6
+ from pathlib import Path
7
+ from typing import Optional, Union
8
+
9
+ import cv2
10
+ import imageio
11
+ import numpy as np
12
+ import torch
13
+ import torch.multiprocessing as mp
14
+ from decord import AudioReader, VideoReader, cpu
15
+ from PIL import Image
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
18
+
19
+
20
+ logging.basicConfig(level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s")
21
+
22
+
23
+ NUM_FRAMES = None # NUM_FRAMES = 160
24
+ MAX_FRAMES = None # MAX_FRAMES = 256
25
+ NUM_FRAMES_PER_SECOND = 10
26
+
27
+
28
+ def get_full_indices(reader: Union[VideoReader, AudioReader]) -> np.ndarray:
29
+ if isinstance(reader, VideoReader):
30
+ return np.linspace(0, len(reader) - 1, len(reader), dtype=int)
31
+ elif isinstance(reader, AudioReader):
32
+ return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int)
33
+
34
+
35
+ def create_output_directories(output_dir):
36
+ try:
37
+ os.makedirs(osp.join(output_dir, "audio"), exist_ok=True)
38
+ os.makedirs(osp.join(output_dir, "video"), exist_ok=True)
39
+ except OSError as e:
40
+ print(f"Error creating directories: {e}")
41
+ raise
42
+
43
+
44
+ def frame_sample(duration, mode="uniform", num_frames=None, fps=None):
45
+ if mode == "uniform":
46
+ assert num_frames is not None, "Number of frames must be provided for uniform sampling."
47
+ # NOTE: v1 version
48
+ # Calculate the size of each segment from which a frame will be extracted
49
+ seg_size = float(duration - 1) / num_frames
50
+
51
+ frame_ids = []
52
+ for i in range(num_frames):
53
+ # Calculate the start and end indices of each segment
54
+ start = seg_size * i
55
+ end = seg_size * (i + 1)
56
+ # Append the middle index of the segment to the list
57
+ frame_ids.append((start + end) / 2)
58
+
59
+ return np.round(np.array(frame_ids) + 1e-6).astype(int)
60
+ # NOTE: v0 version
61
+ # return np.linspace(0, duration-1, num_frames, dtype=int)
62
+ elif mode == "fps":
63
+ assert fps is not None, "FPS must be provided for FPS sampling."
64
+ segment_len = min(fps // NUM_FRAMES_PER_SECOND, duration)
65
+ return np.arange(segment_len // 2, duration, segment_len, dtype=int)
66
+ else:
67
+ raise ImportError(f"Unsupported frame sampling mode: {mode}")
68
+
69
+
70
+ def expand2square(pil_img, background_color):
71
+ width, height = pil_img.size
72
+ if width == height:
73
+ return pil_img
74
+ elif width > height:
75
+ result = Image.new(pil_img.mode, (width, width), background_color)
76
+ result.paste(pil_img, (0, (width - height) // 2))
77
+ return result
78
+ else:
79
+ result = Image.new(pil_img.mode, (height, height), background_color)
80
+ result.paste(pil_img, ((height - width) // 2, 0))
81
+ return result
82
+
83
+
84
+ def process_video(video_path, processor, s=None, e=None, aspect_ratio="pad", num_frames=NUM_FRAMES):
85
+ if isinstance(video_path, str):
86
+ if s is not None and e is not None:
87
+ s = s if s >= 0.0 else 0.0
88
+ e = e if e >= 0.0 else 0.0
89
+ if s > e:
90
+ s, e = e, s
91
+ elif s == e:
92
+ e = s + 1
93
+
94
+ # 1. Loading Video
95
+ if os.path.isdir(video_path):
96
+ frame_files = sorted(os.listdir(video_path))
97
+
98
+ fps = 3
99
+ num_frames_of_video = len(frame_files)
100
+ elif video_path.endswith(".gif"):
101
+ gif_reader = imageio.get_reader(video_path)
102
+
103
+ fps = 25
104
+ num_frames_of_video = len(gif_reader)
105
+ else:
106
+ try:
107
+ vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
108
+ except: # noqa: E722
109
+ return None
110
+
111
+ fps = vreader.get_avg_fps()
112
+ num_frames_of_video = len(vreader)
113
+
114
+ # 2. Determine frame range & Calculate frame indices
115
+ f_start = 0 if s is None else max(int(s * fps) - 1, 0)
116
+ f_end = num_frames_of_video - 1 if e is None else min(int(e * fps) - 1, num_frames_of_video - 1)
117
+ frame_indices = list(range(f_start, f_end + 1))
118
+
119
+ duration = len(frame_indices)
120
+ # 3. Sampling frame indices
121
+ if num_frames is None:
122
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode="fps", fps=fps)]
123
+ else:
124
+ sampled_frame_indices = [
125
+ frame_indices[i] for i in frame_sample(duration, mode="uniform", num_frames=num_frames)
126
+ ]
127
+
128
+ # 4. Acquire frame data
129
+ if os.path.isdir(video_path):
130
+ video_data = [Image.open(os.path.join(video_path, frame_files[f_idx])) for f_idx in sampled_frame_indices]
131
+ elif video_path.endswith(".gif"):
132
+ video_data = [
133
+ Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB))
134
+ for idx, frame in enumerate(gif_reader)
135
+ if idx in sampled_frame_indices
136
+ ]
137
+ else:
138
+ video_data = [Image.fromarray(frame) for frame in vreader.get_batch(sampled_frame_indices).asnumpy()]
139
+
140
+ elif isinstance(video_path, np.ndarray):
141
+ video_data = [Image.fromarray(f) for f in video_path]
142
+ elif isinstance(video_path, list) and isinstance(video_path[0], np.ndarray):
143
+ video_data = [Image.fromarray(f) for f in video_path]
144
+ elif isinstance(video_path, list) and isinstance(video_path[0], str):
145
+ video_data = [Image.open(f) for f in video_path]
146
+ elif isinstance(video_path, list) and isinstance(video_path[0], Image.Image):
147
+ video_data = video_path
148
+ else:
149
+ raise ValueError(f"Unsupported video path type: {type(video_path)}")
150
+
151
+ while num_frames is not None and len(video_data) < num_frames:
152
+ video_data.append(Image.fromarray(np.zeros((video_data[-1].size[1], video_data[-1].size[0], 3), dtype=np.uint8)))  # PIL size is (w, h); numpy expects (h, w, 3)
153
+
154
+ # MAX_FRAMES filter
155
+ if MAX_FRAMES:
156
+ video_data = video_data[:MAX_FRAMES]
157
+
158
+ if aspect_ratio == "pad":
159
+ images = [expand2square(f, tuple(int(x * 255) for x in processor.image_mean)) for f in video_data]
160
+ else:
161
+ images = list(video_data)
162
+ video = processor.preprocess(images, return_tensors="pt")["pixel_values"]
163
+ return video
164
+
165
+
166
+ class VideoFeatureExtractor:
167
+ def __init__(
168
+ self,
169
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = "openai/clip-vit-large-patch14",
170
+ device: str = "cuda",
171
+ ):
172
+ self.device = device
173
+
174
+ self.processor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path)
175
+ self.model = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path).to(self.device).half()
176
+
177
+ def extract_features(self, video_path):
178
+ images = process_video(video_path, self.processor)
179
+ if images is None:
180
+ return None
181
+ clip_feature = self.model(images.to(self.device).half()).image_embeds
182
+
183
+ return clip_feature
184
+
185
+
186
+ def video_processor(item, feature_extractor, output_dir=None):
187
+ video_path = Path(item)
188
+ if not os.path.exists(video_path):
189
+ return
190
+
191
+ clip_feature = feature_extractor.extract_features(str(video_path))
192
+ if clip_feature is None:
193
+ return
194
+
195
+ if output_dir is not None and not os.path.exists(output_dir):
196
+ os.makedirs(output_dir, exist_ok=True)
197
+ output_path = osp.join(output_dir, f"{video_path.stem}.pt")
198
+ else:
199
+ output_path = video_path.with_suffix(".clip")
200
+
201
+ torch.save(clip_feature, output_path)
202
+
203
+
204
+ def s_thread(items, id, device, output_dir):
205
+ feature_extractor = VideoFeatureExtractor(device=device)
206
+ for i, data in tqdm(enumerate(items), total=len(items), position=id):
207
+ video_processor(data, feature_extractor, output_dir)
208
+
209
+
210
+ def load_tensor(file_path, map_location="cpu", weights_only=True):
211
+ try:
212
+ return torch.load(file_path, map_location=map_location, weights_only=weights_only)
213
+ except FileNotFoundError:
214
+ logging.error(f"File not found: {file_path}")
215
+ except torch.serialization.pickle.UnpicklingError:
216
+ logging.error(f"Failed to unpickle file: {file_path}")
217
+ except Exception as e:
218
+ logging.error(f"An error occurred while loading {file_path}: {e}")
219
+ return None
220
+
221
+
222
+ def post_check(directory):
223
+ if not osp.isdir(directory):
224
+ logging.error(f"Invalid directory: {directory}")
225
+ return
226
+
227
+ video_dir = osp.join(directory, "video")
228
+ pt_files = glob.glob(f"{video_dir}/*.pt")
229
+
230
+ for file_path in tqdm(pt_files):
231
+ embeds = load_tensor(file_path)
232
+ if embeds is None:
233
+ continue
234
+
235
+ audio_file_path = file_path.replace("video", "audio")
236
+ audio_text_embeds = load_tensor(audio_file_path)
237
+ if audio_text_embeds is None:
238
+ logging.error(f"Failed to load audio file: {audio_file_path}")
239
+ continue
240
+
241
+ text = audio_text_embeds.get("text")
242
+ mel = audio_text_embeds.get("mel")
243
+ if text is None or mel is None:
244
+ logging.error(f"Missing 'text' or 'mel' in {audio_file_path}")
245
+
246
+
247
+ def args_parse():
248
+ args = argparse.ArgumentParser()
249
+ args.add_argument("--data_type", "-d", type=str, default="video", help="'audio' or 'video'")
250
+ args.add_argument("--check", action="store_true", help="post check, if any pt file was damaged")
251
+
252
+ args.add_argument(
253
+ "--num_threads",
254
+ "-n",
255
+ type=int,
256
+ default=1,
257
+ required=False,
258
+ help="num_threads",
259
+ )
260
+ args.add_argument(
261
+ "--input",
262
+ "-i",
263
+ type=str,
264
+ required=True,
265
+ help="input file path",
266
+ )
267
+ args.add_argument(
268
+ "--output_dir",
269
+ "-o",
270
+ type=str,
271
+ default=None,
272
+ help="output folder path",
273
+ )
274
+ args.add_argument("--multi_gpu", "-m", nargs="+", type=str, default=None, required=False, help="GPU ids")
275
+
276
+ args = args.parse_args()
277
+ return args
278
+
279
+
280
+ if __name__ == "__main__":
281
+ args_main = args_parse()
282
+
283
+ if args_main.check:
284
+ post_check(args_main.output_dir)
285
+ exit(0)
286
+
287
+ gpu_ids = ["cuda:0"]
288
+ if args_main.multi_gpu is not None:
289
+ gpu_ids = [f"cuda:{gpu}" for gpu in args_main.multi_gpu]
290
+
291
+ output_dir = args_main.output_dir
292
+ if output_dir is not None:
293
+ create_output_directories(output_dir)
294
+
295
+ rows = None
296
+ rows = [it.strip() for it in Path(args_main.input).read_text().split("\n") if it.strip() != ""]
297
+
298
+ chunks = np.array_split(rows, args_main.num_threads)
299
+ chunks = [chunk.tolist() for chunk in chunks]
300
+
301
+ processes = []
302
+ mp.set_start_method("spawn", force=True)
303
+ for idx, chunk in enumerate(chunks):
304
+ device = gpu_ids[idx % len(gpu_ids)]
305
+ p = mp.Process(target=s_thread, args=(chunk, idx, device, output_dir))
306
+ processes.append(p)
307
+ p.start()
308
+
309
+ for process in processes:
310
+ process.join()
311
+
312
+ # DEBUG
313
+ # s_thread(args_main, input_dir, output_dir, chunks[0], 0, "cuda:0")
314
+
315
+ print("process done!")
src/moviedubber/infer_with_mmlm_result.py ADDED
@@ -0,0 +1,339 @@
1
+ import argparse
2
+ import os
3
+ import os.path as osp
4
+ import random
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import onnxruntime
10
+ import soundfile
11
+ import tomli
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchaudio
15
+ import torchaudio.compliance.kaldi as kaldi
16
+ from moviepy import AudioFileClip, VideoFileClip
17
+ from omegaconf import OmegaConf
18
+ from pydub import AudioSegment
19
+ from tqdm import tqdm
20
+
21
+
22
+ src_path = Path(osp.dirname(__file__)).parent.parent
23
+ sys.path.insert(0, str(src_path))
24
+ sys.path.append(str(src_path / "src/third_party/BigVGAN"))
25
+
26
+ from src.moviedubber.infer.utils_infer import (
27
+ cfg_strength,
28
+ chunk_text,
29
+ load_model,
30
+ load_vocoder,
31
+ mel_spec_type,
32
+ nfe_step,
33
+ sway_sampling_coef,
34
+ )
35
+ from src.moviedubber.infer.video_preprocess import VideoFeatureExtractor
36
+ from src.moviedubber.model import ControlNetDiT, DiT
37
+ from src.moviedubber.model.utils import convert_char_to_pinyin
38
+
39
+
40
+ def concat_movie_with_audio(wav, video_path, out_dir):
41
+ if not os.path.exists(wav):
42
+ raise FileNotFoundError(f"Audio file {wav} does not exist")
43
+
44
+ if not os.path.exists(video_path):
45
+ raise FileNotFoundError(f"Video file {video_path} does not exist")
46
+
47
+ try:
48
+ with (
49
+ AudioFileClip(str(wav)) as audio_clip,
50
+ VideoFileClip(str(video_path)) as video_clip,
51
+ ):
52
+ duration = min(video_clip.duration, audio_clip.duration)
53
+
54
+ video_subclip = video_clip.subclipped(0, duration)
55
+ audio_subclip = audio_clip.subclipped(0, duration)
56
+
57
+ final_video = video_subclip.with_audio(audio_subclip)
58
+
59
+ output_path = wav.replace(".wav", ".mp4")
60
+
61
+ final_video.write_videofile(
62
+ str(output_path),
63
+ codec="libx264",
64
+ audio_codec="mp3",
65
+ fps=25,
66
+ logger=None,
67
+ threads=1,
68
+ temp_audiofile_path=out_dir,
69
+ )
70
+
71
+ except Exception as e:
72
+ print(f"Error processing {wav} {video_path}: {str(e)}")
73
+
74
+ return output_path
75
+
76
+
77
+ def get_spk_emb(audio_path, ort_session):
78
+ audio, sample_rate = torchaudio.load(str(audio_path))
79
+ if sample_rate != 16000:
80
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
81
+ feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
82
+ feat = feat - feat.mean(dim=0, keepdim=True)
83
+ embedding = (
84
+ ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0]
85
+ .flatten()
86
+ .tolist()
87
+ )
88
+ return embedding
89
+
90
+
91
+ def load_models(config, device):
92
+ model_cfg = config.get("model_cfg", "src/moviedubber/configs/basemodel.yaml")
93
+ ckpt_file = config.get("ckpt_file", None)
94
+ campplus_path = config.get("campplus_path", None)
95
+ vocab_file = config.get("vocab_file", None)
96
+
97
+ vocoder_local_path = config.get("vocoder_local_path", None)
98
+
99
+ if ckpt_file is None or vocab_file is None or vocoder_local_path is None or campplus_path is None:
100
+ raise ValueError("ckpt_file, vocab_file and vocoder_local_path must be specified")
101
+
102
+ vocoder_name = config.get("vocoder_name", mel_spec_type)
103
+
104
+ vocoder = load_vocoder(local_path=vocoder_local_path, device=device)
105
+
106
+ model_cls = DiT
107
+ model_cfg = OmegaConf.load(model_cfg).model.arch
108
+ controlnet = ControlNetDiT
109
+
110
+ ema_model = load_model(
111
+ model_cls,
112
+ model_cfg,
113
+ ckpt_file,
114
+ mel_spec_type=vocoder_name,
115
+ vocab_file=vocab_file,
116
+ controlnet=controlnet,
117
+ device=device,
118
+ )
119
+
120
+ option = onnxruntime.SessionOptions()
121
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
122
+ option.intra_op_num_threads = 1
123
+ providers = ["CPUExecutionProvider"]
124
+ ort_session = onnxruntime.InferenceSession(
125
+ campplus_path,
126
+ sess_options=option,
127
+ providers=providers,
128
+ )
129
+ return ema_model, vocoder, ort_session
130
+
131
+
132
+ def main(config, device, chunk, gen_dir, target_dir, out_dir, idx):
133
+ ema_model, vocoder, ort_session = load_models(config, device=device)
134
+
135
+ videofeature_extractor = VideoFeatureExtractor(device=device)
136
+
137
+ for it in tqdm(chunk, total=len(chunk), position=idx, desc=f"Processing {idx}"):
138
+ wav, video, text, ref_wav = it
139
+
140
+ with open(f"{target_dir}/{wav.split('/')[-1].split('.')[0]}.txt", "a") as f:
141
+ f.write(text + "\n")
142
+
143
+ if wav.endswith(".mp3"):
144
+ audio = AudioSegment.from_mp3(wav)
145
+
146
+ wav_file = wav.replace(".mp3", ".wav")
147
+ audio.export(wav_file, format="wav")
148
+
149
+ wav = Path(wav).with_suffix(".wav")
150
+ if not wav.exists():
151
+ continue
152
+
153
+ os.system(f"cp {wav} {target_dir}/")
154
+
155
+ gen_audio, sr = torchaudio.load(str(wav))
156
+ resampler = torchaudio.transforms.Resample(sr, 24000)
157
+ if sr != 24000:
158
+ gen_audio = resampler(gen_audio)
159
+
160
+ if gen_audio.shape[0] > 1:
161
+ gen_audio = torch.mean(gen_audio, dim=0, keepdim=True)
162
+
163
+ gen_video = video
164
+ gen_clip_path = gen_video.replace(".mp4", ".clip")
165
+
166
+ if not os.path.exists(gen_clip_path):
167
+ gen_clip = videofeature_extractor.extract_features(gen_video)
168
+
169
+ torch.save(gen_clip.detach().cpu(), gen_clip_path)
170
+
171
+ else:
172
+ gen_clip = torch.load(gen_clip_path, weights_only=True).to(device=device, dtype=torch.float32)
173
+
174
+ if ref_wav == "None":
175
+ use_ref_audio = False
176
+ gen_text_ = text
177
+
178
+ gen_clip_ = gen_clip
179
+
180
+ ref_audio_ = gen_audio
181
+
182
+ spk_emb = torch.zeros(1, 1, 192).to(device=device, dtype=torch.float32)
183
+
184
+ else:
185
+ use_ref_audio = True
186
+ ref_audio = Path(ref_wav)
187
+
188
+ spk_emb = get_spk_emb(ref_audio, ort_session)
189
+ spk_emb = torch.tensor(spk_emb).to(device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
190
+
191
+ ref_text = ref_audio.with_suffix(".txt").read_text().strip()
192
+ gen_text_ = ref_text + " " + text
193
+
194
+ if not ref_audio.exists():
195
+ raise Exception(f"ref_audio {ref_audio} not found")
196
+
197
+ if ref_audio.suffix == ".mp3":
198
+ audio = AudioSegment.from_mp3(ref_audio)
199
+
200
+ wav_file = ref_audio.with_suffix(".wav")
201
+ audio.export(wav_file, format="wav")
202
+
203
+ ref_audio_, ref_sr = torchaudio.load(str(ref_audio.with_suffix(".wav")))
204
+ resampler = torchaudio.transforms.Resample(ref_sr, 24000)  # use the reference clip's own sample rate
205
+ if ref_sr != 24000:
206
+ ref_audio_ = resampler(ref_audio_)
207
+
208
+ if ref_audio_.shape[0] > 1:
209
+ ref_audio_ = torch.mean(ref_audio_, dim=0, keepdim=True)
210
+
211
+ ref_video = ref_audio.with_suffix(".mp4")
212
+ ref_clip_path = ref_video.with_suffix(".clip")
213
+
214
+ if not ref_clip_path.exists():
215
+ ref_clip = videofeature_extractor.extract_features(str(ref_video))
216
+
217
+ torch.save(ref_clip.detach().cpu(), ref_clip_path)
218
+
219
+ else:
220
+ ref_clip = torch.load(ref_clip_path, weights_only=True).to(device=device, dtype=torch.float32)
221
+
222
+ gen_clip_ = torch.cat([ref_clip, gen_clip], dim=0)
223
+
224
+ gen_audio_len = gen_audio.shape[1] // 256
225
+
226
+ if use_ref_audio:
227
+ ref_audio_len = ref_audio_.shape[1] // 256
228
+ duration = ref_audio_len + gen_audio_len
229
+ else:
230
+ duration = gen_audio_len
231
+
232
+ gen_clip_ = gen_clip_.unsqueeze(0).to(device=device, dtype=torch.float32).transpose(1, 2)
233
+ gen_clip_ = F.interpolate(gen_clip_, size=duration, mode="linear", align_corners=False).transpose(1, 2)
234
+
235
+ gen_text_batches = chunk_text(gen_text_, max_chars=1024)
236
+ final_text_list = convert_char_to_pinyin(gen_text_batches)
237
+
238
+ with torch.inference_mode():
239
+ generated, _ = ema_model.sample(
240
+ cond=ref_audio_.to(device),
241
+ text=final_text_list,
242
+ clip=gen_clip_,
243
+ spk_emb=spk_emb,
244
+ duration=duration,
245
+ steps=nfe_step,
246
+ cfg_strength=cfg_strength,
247
+ sway_sampling_coef=sway_sampling_coef,
248
+ no_ref_audio=not use_ref_audio,
249
+ )
250
+
251
+ generated = generated.to(torch.float32)
252
+
253
+ if use_ref_audio:
254
+ generated = generated[:, ref_audio_len:, :]
255
+
256
+ generated_mel_spec = generated.permute(0, 2, 1)
257
+ generated_wave = vocoder(generated_mel_spec)
258
+
259
+ generated_wave = generated_wave.squeeze().cpu().numpy()
260
+
261
+ out_path = osp.join(gen_dir, f"{wav.stem}.wav")
262
+ soundfile.write(out_path, generated_wave, samplerate=24000)
263
+ _ = concat_movie_with_audio(out_path, gen_video, out_dir)
264
+
265
+
266
+ if __name__ == "__main__":
267
+ import torch.multiprocessing as mp
268
+
269
+ parser = argparse.ArgumentParser(
270
+ prog="python3 infer-cli.py",
271
+ description="Commandline interface for moviedubber infer with Advanced Batch Processing.",
272
+ epilog="Specify options above to override one or more settings from config.",
273
+ )
274
+ parser.add_argument(
275
+ "-c",
276
+ "--config",
277
+ type=str,
278
+ default="src/moviedubber/infer/basic.toml",
279
+ help="The configuration file, default see infer/basic.toml",
280
+ )
281
+ parser.add_argument("-i", "--input_list", type=str, required=True, help="The val list file")
282
+ parser.add_argument("-s", "--ref_spk_list", type=str, required=True, help="The spk list file")
283
+ parser.add_argument("-o", "--out_dir", type=str, default="data/dubberout", help="The output directory")
284
+ parser.add_argument("--gpuids", type=str, help="GPU ids to use, split by comma")
285
+ parser.add_argument("--nums_workers", type=int, default=1, help="Number of workers for per gpu")
286
+
287
+ args = parser.parse_args()
288
+
289
+ out_dir = args.out_dir
290
+ input_list = args.input_list
291
+ gpu_ids = args.gpuids.split(",") if args.gpuids else ["0"]
292
+ num_pre = args.nums_workers
293
+ spk_ref_path = args.ref_spk_list
294
+
295
+ config = tomli.load(open(args.config, "rb"))
296
+
297
+ gen_lst = Path(input_list).read_text().splitlines()[1:]
298
+
299
+ gen_pre_conf = []
300
+
301
+ spk_lines = Path(spk_ref_path).read_text().splitlines()
302
+
303
+ for idx, line in enumerate(gen_lst):
304
+ if line.strip():
305
+ mp4_path, is_correc, _, _ = line.split(",")
306
+
307
+ wav_path = mp4_path.replace(".mp4", ".mp3")
308
+ text = Path(wav_path.replace(".mp3", ".txt")).read_text().strip()
309
+
310
+ if is_correc == "True":
311
+ ref_wav = spk_lines[idx].split(",")[1].strip()
312
+ else:
313
+ ref_wav = random.choice(spk_lines).split(",")[-1].strip() # Use random speaker for incorrect samples
314
+
315
+ gen_pre_conf.append([wav_path, mp4_path, text, ref_wav])
316
+
317
+ chunks = np.array_split(gen_pre_conf, len(gpu_ids) * num_pre)
318
+
319
+ gen_dir = os.path.join(out_dir, "generated")
320
+ target_dir = os.path.join(out_dir, "target")
321
+
322
+ if not os.path.exists(gen_dir) or not os.path.exists(target_dir):
323
+ os.makedirs(gen_dir, exist_ok=True)
324
+ os.makedirs(target_dir, exist_ok=True)
325
+
326
+ mp.set_start_method("spawn", force=True)
327
+ processes = []
328
+ for idx, chunk in enumerate(chunks):
329
+ device = gpu_ids[idx % len(gpu_ids)]
330
+
331
+ device = f"cuda:{device}"
332
+ p = mp.Process(target=main, args=(config, device, chunk, gen_dir, target_dir, out_dir, idx))
333
+ processes.append(p)
334
+ p.start()
335
+
336
+ for process in processes:
337
+ process.join()
338
+
339
+ print("All processes finished.")
src/moviedubber/model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .cfm import CFM
2
+ from .dit import ControlNetDiT, DiT
3
+
4
+
5
+ __all__ = ["CFM", "UNetT", "DiT", "ControlNetDiT", "MMDiT", "Trainer"]
src/moviedubber/model/cfm.py ADDED
@@ -0,0 +1,209 @@
1
+ # modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/cfm.py
2
+
3
+ """
4
+ ein notation:
5
+ b - batch
6
+ n - sequence
7
+ nt - text sequence
8
+ nw - raw wave length
9
+ d - dimension
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Callable
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+ from torch.nn.utils.rnn import pad_sequence
20
+ from torchdiffeq import odeint
21
+
22
+ from .modules import MelSpec
23
+ from .utils import (
24
+ default,
25
+ exists,
26
+ lens_to_mask,
27
+ list_str_to_idx,
28
+ list_str_to_tensor,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma=0.0,
37
+ odeint_kwargs: dict = dict(
38
+ method="euler" # 'midpoint'
39
+ ),
40
+ num_channels=None,
41
+ mel_spec_module: nn.Module | None = None,
42
+ mel_spec_kwargs: dict = dict(),
43
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
44
+ vocab_char_map: dict[str, int] | None = None,
45
+ controlnet: nn.Module | None = None,
46
+ ):
47
+ super().__init__()
48
+
49
+ self.frac_lengths_mask = frac_lengths_mask
50
+
51
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
52
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
53
+ self.num_channels = num_channels
54
+
55
+ self.transformer = transformer
56
+ dim = transformer.dim
57
+ self.dim = dim
58
+
59
+ self.sigma = sigma
60
+
61
+ self.odeint_kwargs = odeint_kwargs
62
+
63
+ self.vocab_char_map = vocab_char_map
64
+
65
+ self.controlnet = controlnet
66
+
67
+ @property
68
+ def device(self):
69
+ return next(self.parameters()).device
70
+
71
+ @torch.no_grad()
72
+ def sample(
73
+ self,
74
+ cond: float["b n d"] | float["b nw"], # noqa: F722
75
+ text: int["b nt"] | list[str], # noqa: F722
76
+ clip: float["b n d"], # noqa: F722
77
+ duration: int | int["b"], # noqa: F821
78
+ *,
79
+ caption_emb: float["b n d"] | None = None, # noqa: F722
80
+ spk_emb: float["b n d"] | None = None, # noqa: F722
81
+ lens: int["b"] | None = None, # noqa: F821
82
+ steps=32,
83
+ cfg_strength=1.0,
84
+ sway_sampling_coef=None,
85
+ seed: int | None = None,
86
+ max_duration=4096,
87
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
88
+ no_ref_audio=False,
89
+ duplicate_test=False,
90
+ t_inter=0.1,
91
+ edit_mask=None,
92
+ ):
93
+ self.eval()
94
+
95
+ if cond.ndim == 2:
96
+ cond = self.mel_spec(cond)
97
+ cond = cond.permute(0, 2, 1)
98
+ assert cond.shape[-1] == self.num_channels
99
+
100
+ cond = cond.to(next(self.parameters()).dtype)
101
+
102
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
103
+ if not exists(lens):
104
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
105
+
106
+ if isinstance(text, list):
107
+ if exists(self.vocab_char_map):
108
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
109
+ else:
110
+ text = list_str_to_tensor(text).to(device)
111
+ assert text.shape[0] == batch
112
+
113
+ if exists(text):
114
+ text_lens = (text != -1).sum(dim=-1)
115
+ lens = torch.maximum(text_lens, lens)
116
+
117
+ cond_mask = lens_to_mask(lens)
118
+ if edit_mask is not None:
119
+ cond_mask = cond_mask & edit_mask
120
+
121
+ if isinstance(duration, int):
122
+ duration = torch.full((batch,), duration, device=device, dtype=torch.long)
123
+
124
+ # duration = torch.maximum(lens + 1, duration)
125
+
126
+ duration = duration.clamp(max=max_duration)
127
+ max_duration = duration.amax()
128
+
129
+ if duplicate_test:
130
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
131
+
132
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
133
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
134
+ cond_mask = cond_mask.unsqueeze(-1)
135
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
136
+
137
+ if batch > 1:
138
+ mask = lens_to_mask(duration)
139
+ else:
140
+ mask = None
141
+
142
+ if no_ref_audio:
143
+ cond = torch.zeros_like(cond)
144
+
145
+ def fn(t, x):
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
147
+
148
+ controlnet_embeds = self.controlnet(
149
+ x=x,
150
+ text=text,
151
+ clip=clip,
152
+ spk_emb=spk_emb,
153
+ caption=caption_emb,
154
+ time=t,
155
+ )
156
+
157
+ cond_pred = self.transformer(
158
+ x=x,
159
+ cond=step_cond,
160
+ text=text,
161
+ time=t,
162
+ mask=mask,
163
+ drop_audio_cond=[False],
164
+ drop_text=[False],
165
+ controlnet_embeds=controlnet_embeds,
166
+ )
167
+
168
+ null_pred = self.transformer(
169
+ x=x,
170
+ cond=step_cond,
171
+ text=text,
172
+ time=t,
173
+ mask=mask,
174
+ drop_audio_cond=[True],
175
+ drop_text=[True],
176
+ controlnet_embeds=None,
177
+ )
178
+
179
+ return null_pred + (cond_pred - null_pred) * cfg_strength  # apply the configured guidance scale instead of a hard-coded factor
180
+
181
+ y0 = []
182
+ for dur in duration:
183
+ if exists(seed):
184
+ torch.manual_seed(seed)
185
+ y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
186
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
187
+
188
+ t_start = 0
189
+
190
+ if duplicate_test:
191
+ t_start = t_inter
192
+ y0 = (1 - t_start) * y0 + t_start * test_cond
193
+ steps = int(steps * (1 - t_start))
194
+
195
+ t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
196
+ if sway_sampling_coef is not None:
197
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
198
+
199
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
200
+
201
+ sampled = trajectory[-1]
202
+ out = sampled
203
+ out = torch.where(cond_mask, cond, out)
204
+
205
+ if exists(vocoder):
206
+ out = out.permute(0, 2, 1)
207
+ out = vocoder(out)
208
+
209
+ return out, trajectory
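To make the sway-sampling step in `CFM.sample` concrete, the sketch below applies the same warp `t ← t + c · (cos(πt/2) − 1 + t)` to a uniform schedule; the step count is reduced for readability, the coefficient matches the −1 default used above, and the snippet is illustrative only.

```python
import torch

steps = 8
sway_sampling_coef = -1.0

# Uniform ODE timesteps in [0, 1], as built in CFM.sample before warping.
t = torch.linspace(0, 1, steps + 1)

# Sway sampling: with a negative coefficient the warped schedule is denser near t = 0,
# so more solver steps are spent early in the flow-matching trajectory.
t_sway = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

print(t.tolist())
print(t_sway.tolist())
```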
src/moviedubber/model/dit.py ADDED
@@ -0,0 +1,297 @@
1
+ # modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/backbones/dit.py
2
+
3
+ """
4
+ ein notation:
5
+ b - batch
6
+ n - sequence
7
+ nt - text sequence
8
+ nw - raw wave length
9
+ d - dimension
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+
19
+ from .modules import (
20
+ AdaLayerNormZero_Final,
21
+ ConvNeXtV2Block,
22
+ ConvPositionEmbedding,
23
+ DiTBlock,
24
+ TimestepEmbedding,
25
+ get_pos_embed_indices,
26
+ precompute_freqs_cis,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+
33
+ class TextEmbedding(nn.Module):
34
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
35
+ super().__init__()
36
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
37
+
38
+ if conv_layers > 0:
39
+ self.extra_modeling = True
40
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
41
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
42
+ self.text_blocks = nn.Sequential(
43
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
44
+ )
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
49
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
50
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
51
+ batch, text_len = text.shape[0], text.shape[1]
52
+ text = F.pad(text, (0, seq_len - text_len), value=0)
53
+
54
+ for idx, _drop in enumerate(drop_text): # cfg for text
55
+ if _drop:
56
+ text[idx] = torch.zeros_like(text[idx])
57
+
58
+ text = self.text_embed(text) # b n -> b n d
59
+
60
+ # possible extra modeling
61
+ if self.extra_modeling:
62
+ # sinus pos emb
63
+ batch_start = torch.zeros((batch,), dtype=torch.long)
64
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
65
+ text_pos_embed = self.freqs_cis[pos_idx]
66
+ text = text + text_pos_embed
67
+
68
+ # convnextv2 blocks
69
+ text = self.text_blocks(text)
70
+
71
+ return text
72
+
73
+
74
+ # noised input audio and context mixing embedding
75
+
76
+
77
+ class InputEmbedding(nn.Module):
78
+ def __init__(self, mel_dim, text_dim, out_dim):
79
+ super().__init__()
80
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
81
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
82
+
83
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
84
+ for idx, _drop in enumerate(drop_audio_cond): # cfg for cond audio
85
+ if _drop:
86
+ cond[idx] = torch.zeros_like(cond[idx])
87
+
88
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
+ x = self.conv_pos_embed(x) + x
90
+ return x
91
+
92
+
93
+ class InputEmbeddingO(nn.Module):
94
+ def __init__(self, mel_dim, text_dim, out_dim):
95
+ super().__init__()
96
+ self.proj = nn.Linear(mel_dim + 512 + text_dim + 192 + 32, out_dim)
97
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
98
+
99
+ def forward(
100
+ self,
101
+ x: float["b n d"], # noqa: F722
102
+ text_emb: float["b n d"], # noqa: F722
103
+ video_emb: float["b n d"], # noqa: F722
104
+ spk_emb: float["b n d"], # noqa: F722
105
+ caption_emb: float["b n d"], # noqa: F722
106
+ ):
107
+ x = self.proj(torch.cat((x, text_emb, video_emb, spk_emb, caption_emb), dim=-1))
108
+ x = self.conv_pos_embed(x) + x
109
+ return x
110
+
111
+
112
+ # Transformer backbone using DiT blocks
113
+
114
+
115
+ class DiT(nn.Module):
116
+ def __init__(
117
+ self,
118
+ *,
119
+ dim,
120
+ depth=8,
121
+ heads=8,
122
+ dim_head=64,
123
+ dropout=0.1,
124
+ ff_mult=4,
125
+ mel_dim=100,
126
+ text_num_embeds=256,
127
+ text_dim=None,
128
+ conv_layers=0,
129
+ long_skip_connection=False,
130
+ ):
131
+ super().__init__()
132
+
133
+ self.time_embed = TimestepEmbedding(dim)
134
+ if text_dim is None:
135
+ text_dim = mel_dim
136
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
137
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
138
+
139
+ self.rotary_embed = RotaryEmbedding(dim_head)
140
+
141
+ self.dim = dim
142
+ self.depth = depth
143
+
144
+ self.transformer_blocks = nn.ModuleList(
145
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
146
+ )
147
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
148
+
149
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
150
+ self.proj_out = nn.Linear(dim, mel_dim)
151
+
152
+ def forward(
153
+ self,
154
+ x: float["b n d"], # nosied input audio # noqa: F722
155
+ cond: float["b n d"], # masked cond audio # noqa: F722
156
+ text: int["b nt"], # text # noqa: F722
157
+ time: float["b"] | float[""], # time step # noqa: F821 F722
158
+ drop_audio_cond, # cfg for cond audio
159
+ drop_text, # cfg for text
160
+ mask: bool["b n"] | None = None, # noqa: F722
161
+ controlnet_embeds: float["b n d"] | None = None, # noqa: F722
162
+ ):
163
+ batch, seq_len = x.shape[0], x.shape[1]
164
+ if time.ndim == 0:
165
+ time = time.repeat(batch)
166
+
167
+ t = self.time_embed(time)
168
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
169
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
170
+
171
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
172
+
173
+ if self.long_skip_connection is not None:
174
+ residual = x
175
+
176
+ for i, block in enumerate(self.transformer_blocks):
177
+ if controlnet_embeds is not None and i < 12:
178
+ x += controlnet_embeds[i]
179
+
180
+ x = block(x, t, mask=mask, rope=rope)
181
+
182
+ if self.long_skip_connection is not None:
183
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
184
+
185
+ x = self.norm_out(x, t)
186
+ output = self.proj_out(x)
187
+
188
+ return output
189
+
190
+
191
+ class ControlNetDiT(nn.Module):
192
+ def __init__(
193
+ self,
194
+ *,
195
+ dim,
196
+ depth=8,
197
+ heads=8,
198
+ dim_head=64,
199
+ dropout=0.1,
200
+ ff_mult=4,
201
+ mel_dim=100,
202
+ text_num_embeds=256,
203
+ text_dim=None,
204
+ conv_layers=0,
205
+ long_skip_connection=False,
206
+ checkpoint_activations=False,
207
+ duration_predictor=None,
208
+ ):
209
+ super().__init__()
210
+
211
+ if text_dim is None:
212
+ text_dim = mel_dim
213
+
214
+ self.time_embed = TimestepEmbedding(dim)
215
+
216
+ self.rotary_embed = RotaryEmbedding(dim_head)
217
+
218
+ self.dim = dim
219
+ self.depth = depth // 2 + 1
220
+
221
+ self.transformer_blocks1 = nn.ModuleList(
222
+ [
223
+ DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout)
224
+ for _ in range(self.depth)
225
+ ]
226
+ )
227
+
228
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
229
+
230
+ self.input_embed = InputEmbeddingO(mel_dim, text_dim, dim)
231
+
232
+ self.spk_embed_affine_layer = torch.nn.Linear(192, 192)
233
+ self.clip_embed_affine_layer = torch.nn.Linear(768, 512)
234
+ self.caption_embed_affine_layer = torch.nn.Linear(512, 32)
235
+
236
+ self.zero_linear = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(12)])
237
+ for zero_linear in self.zero_linear:
238
+ nn.init.zeros_(zero_linear.weight)
239
+
240
+ self.duration_predictor = duration_predictor
241
+
242
+ def forward(
243
+ self,
244
+ x: float["b n d"], # nosied input audio # noqa: F722
245
+ text: int["b nt"], # text # noqa: F722
246
+ clip: float["b n d"], # video clip # noqa: F722
247
+ spk_emb: float["b d"], # speaker embedding # noqa: F722
248
+ time: float["b"] | float[""], # time step # noqa: F821 F722
249
+ caption: float["b nt"] | None = None, # caption # noqa: F722
250
+ mask: bool["b n"] | None = None, # noqa: F722
251
+ lens: int["b"] | None = None, # noqa: F722, F821
252
+ return_dur: bool = False, # return duration prediction
253
+ ):
254
+ batch, seq_len = x.shape[0], x.shape[1]
255
+
256
+ if time.ndim == 0:
257
+ time = time.repeat(batch)
258
+
259
+ t = self.time_embed(time)
260
+
261
+ clip_emb = F.normalize(clip, dim=-1)
262
+ clip_emb = self.clip_embed_affine_layer(clip_emb)  # project the normalized clip features, not the raw input
263
+
264
+ spk_emb = F.normalize(spk_emb, dim=-1)
265
+ spk_emb = self.spk_embed_affine_layer(spk_emb)
266
+ spk_emb = torch.repeat_interleave(spk_emb, seq_len, dim=1)
267
+
268
+ if caption is None:
269
+ caption = torch.zeros(1, seq_len, 512).to(device=x.device)
270
+
271
+ caption_emb = F.normalize(caption, dim=-1)
272
+ caption_emb = self.caption_embed_affine_layer(caption_emb)
273
+
274
+ text_embed = self.text_embed(text, seq_len, drop_text=[False])
275
+
276
+ x = self.input_embed(x, text_embed, clip_emb, spk_emb, caption_emb)
277
+
278
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
279
+
280
+ info = []
281
+ for i, block in enumerate(self.transformer_blocks1):
282
+ x = block(x, t, mask=mask, rope=rope) # 'b n 1024'
283
+
284
+ info.append(x)
285
+
286
+ out_info = []
287
+ for i, linear in enumerate(self.zero_linear):
288
+ h = linear(info[i])
289
+ out_info.append(h)
290
+
291
+ if return_dur and self.duration_predictor is not None:
292
+ dur_loss = self.duration_predictor(x=x, text=clip_emb, lens=lens)
293
+
294
+ return out_info, dur_loss
295
+
296
+ else:
297
+ return out_info
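The coupling between `ControlNetDiT` and `DiT` above follows the ControlNet pattern: each control-branch hidden state passes through a zero-initialized linear layer and is added residually into the first 12 `DiT` blocks, so at initialization the control branch is a no-op. The standalone sketch below illustrates only that mechanism; the dimensions and random tensors are placeholders, not repository defaults.

```python
import torch
from torch import nn

dim, n_inject, seq_len = 1024, 12, 100  # illustrative sizes only

# Zero-initialized projections: the control signal contributes nothing at init,
# so the base DiT behaviour is preserved and the branch is learned gradually.
zero_linear = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(n_inject)])
for layer in zero_linear:
    nn.init.zeros_(layer.weight)

hidden = [torch.randn(1, seq_len, dim) for _ in range(n_inject)]  # stand-ins for control-branch states
controlnet_embeds = [layer(h) for layer, h in zip(zero_linear, hidden)]

x = torch.randn(1, seq_len, dim)  # stand-in for a DiT block input
x = x + controlnet_embeds[0]      # as done for blocks i < 12 in DiT.forward
print(controlnet_embeds[0].abs().max().item())  # 0.0 at initialization
```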
src/moviedubber/model/modules.py ADDED
@@ -0,0 +1,467 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from librosa.filters import mel as librosa_mel_fn
18
+ from torch import nn
19
+ from x_transformers.x_transformers import apply_rotary_pos_emb
20
+
21
+
22
+ # raw wav to mel spec
23
+
24
+
25
+ mel_basis_cache = {}
26
+ hann_window_cache = {}
27
+
28
+
29
+ def get_bigvgan_mel_spectrogram(
30
+ waveform,
31
+ n_fft=1024,
32
+ n_mel_channels=100,
33
+ target_sample_rate=24000,
34
+ hop_length=256,
35
+ win_length=1024,
36
+ fmin=0,
37
+ fmax=None,
38
+ center=False,
39
+ ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
40
+ device = waveform.device
41
+ key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
42
+
43
+ if key not in mel_basis_cache:
44
+ mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
45
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
46
+ hann_window_cache[key] = torch.hann_window(win_length).to(device)
47
+
48
+ mel_basis = mel_basis_cache[key]
49
+ hann_window = hann_window_cache[key]
50
+
51
+ padding = (n_fft - hop_length) // 2
52
+ waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
53
+
54
+ spec = torch.stft(
55
+ waveform,
56
+ n_fft,
57
+ hop_length=hop_length,
58
+ win_length=win_length,
59
+ window=hann_window,
60
+ center=center,
61
+ pad_mode="reflect",
62
+ normalized=False,
63
+ onesided=True,
64
+ return_complex=True,
65
+ )
66
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
67
+
68
+ mel_spec = torch.matmul(mel_basis, spec)
69
+ mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
70
+
71
+ return mel_spec
72
+
73
+
74
+ class MelSpec(nn.Module):
75
+ def __init__(
76
+ self,
77
+ n_fft=1024,
78
+ hop_length=256,
79
+ win_length=1024,
80
+ n_mel_channels=100,
81
+ target_sample_rate=24_000,
82
+ mel_spec_type="bigvgan",
83
+ ):
84
+ super().__init__()
85
+
86
+ self.n_fft = n_fft
87
+ self.hop_length = hop_length
88
+ self.win_length = win_length
89
+ self.n_mel_channels = n_mel_channels
90
+ self.target_sample_rate = target_sample_rate
91
+
92
+ self.extractor = get_bigvgan_mel_spectrogram
93
+
94
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
95
+
96
+ def forward(self, wav):
97
+ if self.dummy.device != wav.device:
98
+ self.to(wav.device)
99
+
100
+ mel = self.extractor(
101
+ waveform=wav,
102
+ n_fft=self.n_fft,
103
+ n_mel_channels=self.n_mel_channels,
104
+ target_sample_rate=self.target_sample_rate,
105
+ hop_length=self.hop_length,
106
+ win_length=self.win_length,
107
+ )
108
+
109
+ return mel
110
+
111
+
112
+ # sinusoidal position embedding
113
+
114
+
115
+ class SinusPositionEmbedding(nn.Module):
116
+ def __init__(self, dim):
117
+ super().__init__()
118
+ self.dim = dim
119
+
120
+ def forward(self, x, scale=1000):
121
+ device = x.device
122
+ half_dim = self.dim // 2
123
+ emb = math.log(10000) / (half_dim - 1)
124
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
125
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
126
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
127
+ return emb
128
+
129
+
130
+ # convolutional position embedding
131
+
132
+
133
+ class ConvPositionEmbedding(nn.Module):
134
+ def __init__(self, dim, kernel_size=31, groups=16):
135
+ super().__init__()
136
+ assert kernel_size % 2 != 0
137
+ self.conv1d = nn.Sequential(
138
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
139
+ nn.Mish(),
140
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
141
+ nn.Mish(),
142
+ )
143
+
144
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
145
+ if mask is not None:
146
+ mask = mask[..., None]
147
+ x = x.masked_fill(~mask, 0.0)
148
+
149
+ x = x.permute(0, 2, 1)
150
+ x = self.conv1d(x)
151
+ out = x.permute(0, 2, 1)
152
+
153
+ if mask is not None:
154
+ out = out.masked_fill(~mask, 0.0)
155
+
156
+ return out
157
+
158
+
159
+ # rotary positional embedding related
160
+
161
+
162
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
163
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
164
+ # has some connection to NTK literature
165
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
166
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
167
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
168
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
169
+ t = torch.arange(end, device=freqs.device) # type: ignore
170
+ freqs = torch.outer(t, freqs).float() # type: ignore
171
+ freqs_cos = torch.cos(freqs) # real part
172
+ freqs_sin = torch.sin(freqs) # imaginary part
173
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
174
+
175
+
176
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
177
+ # length = length if isinstance(length, int) else length.max()
178
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
179
+ pos = (
180
+ start.unsqueeze(1)
181
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
182
+ )
183
+ # avoid extra long error.
184
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
185
+ return pos
186
+
187
+
188
+ # Global Response Normalization layer (Instance Normalization ?)
189
+
190
+
191
+ class GRN(nn.Module):
192
+ def __init__(self, dim):
193
+ super().__init__()
194
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
195
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
196
+
197
+ def forward(self, x):
198
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
199
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
200
+ return self.gamma * (x * Nx) + self.beta + x
201
+
202
+
203
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
204
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
205
+
206
+
207
+ class ConvNeXtV2Block(nn.Module):
208
+ def __init__(
209
+ self,
210
+ dim: int,
211
+ intermediate_dim: int,
212
+ dilation: int = 1,
213
+ ):
214
+ super().__init__()
215
+ padding = (dilation * (7 - 1)) // 2
216
+ self.dwconv = nn.Conv1d(
217
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
218
+ ) # depthwise conv
219
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
220
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
221
+ self.act = nn.GELU()
222
+ self.grn = GRN(intermediate_dim)
223
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
224
+
225
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
226
+ residual = x
227
+ x = x.transpose(1, 2) # b n d -> b d n
228
+ x = self.dwconv(x)
229
+ x = x.transpose(1, 2) # b d n -> b n d
230
+ x = self.norm(x)
231
+ x = self.pwconv1(x)
232
+ x = self.act(x)
233
+ x = self.grn(x)
234
+ x = self.pwconv2(x)
235
+ return residual + x
236
+
237
+
238
+ # AdaLayerNormZero
239
+ # return with modulated x for attn input, and params for later mlp modulation
240
+
241
+
242
+ class AdaLayerNormZero(nn.Module):
243
+ def __init__(self, dim):
244
+ super().__init__()
245
+
246
+ self.silu = nn.SiLU()
247
+ self.linear = nn.Linear(dim, dim * 6)
248
+
249
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
250
+
251
+ def forward(self, x, emb=None):
252
+ emb = self.linear(self.silu(emb))
253
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
254
+
255
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
256
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
257
+
258
+
259
+ # AdaLayerNormZero for final layer
260
+ # return only with modulated x for attn input, cuz no more mlp modulation
261
+
262
+
263
+ class AdaLayerNormZero_Final(nn.Module):
264
+ def __init__(self, dim):
265
+ super().__init__()
266
+
267
+ self.silu = nn.SiLU()
268
+ self.linear = nn.Linear(dim, dim * 2)
269
+
270
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
271
+
272
+ def forward(self, x, emb):
273
+ emb = self.linear(self.silu(emb))
274
+ scale, shift = torch.chunk(emb, 2, dim=1)
275
+
276
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
277
+ return x
278
+
279
+
280
+ # FeedForward
281
+
282
+
283
+ class FeedForward(nn.Module):
284
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
285
+ super().__init__()
286
+ inner_dim = int(dim * mult)
287
+ dim_out = dim_out if dim_out is not None else dim
288
+
289
+ activation = nn.GELU(approximate=approximate)
290
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
291
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
292
+
293
+ def forward(self, x):
294
+ return self.ff(x)
295
+
296
+
297
+ # Attention with possible joint part
298
+ # modified from diffusers/src/diffusers/models/attention_processor.py
299
+
300
+
301
+ class Attention(nn.Module):
302
+ def __init__(
303
+ self,
304
+ processor: AttnProcessor,
305
+ dim: int,
306
+ heads: int = 8,
307
+ dim_head: int = 64,
308
+ dropout: float = 0.0,
309
+ context_dim: Optional[int] = None, # if not None -> joint attention
310
+ context_pre_only=None,
311
+ ):
312
+ super().__init__()
313
+
314
+ if not hasattr(F, "scaled_dot_product_attention"):
315
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
316
+
317
+ self.processor = processor
318
+
319
+ self.dim = dim
320
+ self.heads = heads
321
+ self.inner_dim = dim_head * heads
322
+ self.dropout = dropout
323
+
324
+ self.context_dim = context_dim
325
+ self.context_pre_only = context_pre_only
326
+
327
+ self.to_q = nn.Linear(dim, self.inner_dim)
328
+ self.to_k = nn.Linear(dim, self.inner_dim)
329
+ self.to_v = nn.Linear(dim, self.inner_dim)
330
+
331
+ if self.context_dim is not None:
332
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
333
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
334
+ if self.context_pre_only is not None:
335
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
336
+
337
+ self.to_out = nn.ModuleList([])
338
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
339
+ self.to_out.append(nn.Dropout(dropout))
340
+
341
+ if self.context_pre_only is not None and not self.context_pre_only:
342
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
343
+
344
+ def forward(
345
+ self,
346
+ x: float["b n d"], # noised input x # noqa: F722
347
+ c: float["b n d"] = None, # context c # noqa: F722
348
+ mask: bool["b n"] | None = None, # noqa: F722
349
+ rope=None, # rotary position embedding for x
350
+ c_rope=None, # rotary position embedding for c
351
+ ) -> torch.Tensor:
352
+ if c is not None:
353
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
354
+ else:
355
+ return self.processor(self, x, mask=mask, rope=rope)
356
+
357
+
358
+ # Attention processor
359
+
360
+
361
+ class AttnProcessor:
362
+ def __init__(self):
363
+ pass
364
+
365
+ def __call__(
366
+ self,
367
+ attn: Attention,
368
+ x: float["b n d"], # noised input x # noqa: F722
369
+ mask: bool["b n"] | None = None, # noqa: F722
370
+ rope=None, # rotary position embedding
371
+ ) -> torch.FloatTensor:
372
+ batch_size = x.shape[0]
373
+
374
+ # `sample` projections.
375
+ query = attn.to_q(x)
376
+ key = attn.to_k(x)
377
+ value = attn.to_v(x)
378
+
379
+ # apply rotary position embedding
380
+ if rope is not None:
381
+ freqs, xpos_scale = rope
382
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
383
+
384
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
385
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
386
+
387
+ # attention
388
+ inner_dim = key.shape[-1]
389
+ head_dim = inner_dim // attn.heads
390
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
391
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
392
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
393
+
394
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
395
+ if mask is not None:
396
+ attn_mask = mask
397
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
398
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
399
+ else:
400
+ attn_mask = None
401
+
402
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
403
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
404
+ x = x.to(query.dtype)
405
+
406
+ # linear proj
407
+ x = attn.to_out[0](x)
408
+ # dropout
409
+ x = attn.to_out[1](x)
410
+
411
+ if mask is not None:
412
+ mask = mask.unsqueeze(-1)
413
+ x = x.masked_fill(~mask, 0.0)
414
+
415
+ return x
416
+
417
+
418
+ # DiT Block
419
+
420
+
421
+ class DiTBlock(nn.Module):
422
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
423
+ super().__init__()
424
+
425
+ self.attn_norm = AdaLayerNormZero(dim)
426
+ self.attn = Attention(
427
+ processor=AttnProcessor(),
428
+ dim=dim,
429
+ heads=heads,
430
+ dim_head=dim_head,
431
+ dropout=dropout,
432
+ )
433
+
434
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
435
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
436
+
437
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
438
+ # pre-norm & modulation for attention input
439
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
440
+
441
+ # attention
442
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
443
+
444
+ # process attention output for input x
445
+ x = x + gate_msa.unsqueeze(1) * attn_output
446
+
447
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
448
+ ff_output = self.ff(norm)
449
+ x = x + gate_mlp.unsqueeze(1) * ff_output
450
+
451
+ return x
452
+
453
+
454
+ # time step conditioning embedding
455
+
456
+
457
+ class TimestepEmbedding(nn.Module):
458
+ def __init__(self, dim, freq_embed_dim=256):
459
+ super().__init__()
460
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
461
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
462
+
463
+ def forward(self, timestep: float["b"]): # noqa: F821
464
+ time_hidden = self.time_embed(timestep)
465
+ time_hidden = time_hidden.to(timestep.dtype)
466
+ time = self.time_mlp(time_hidden) # b d
467
+ return time
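The `AttnProcessor` above expands a boolean padding mask from `b n` to `b 1 1 n` before calling `F.scaled_dot_product_attention`, then zeroes the padded query positions in the output. A minimal, self-contained sketch of that mask handling, using made-up shapes outside the model, might look like this:

```python
import torch
import torch.nn.functional as F

# hypothetical shapes for illustration only
batch, heads, seq_len, head_dim = 2, 4, 6, 16
lens = torch.tensor([6, 3])  # per-sample valid lengths in the batch

# bool padding mask, True = valid position ('b n')
mask = torch.arange(seq_len)[None, :] < lens[:, None]

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# 'b n' -> 'b 1 1 n', then broadcast over heads and query positions
attn_mask = mask[:, None, None, :].expand(batch, heads, seq_len, seq_len)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# zero the padded query positions afterwards, as the processor does with masked_fill
out = out.masked_fill(~mask[:, None, :, None], 0.0)
print(out.shape)  # torch.Size([2, 4, 6, 16])
```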
src/moviedubber/model/utils.py ADDED
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ import jieba
5
+ import torch
6
+ from pypinyin import Style, lazy_pinyin
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
+ def exists(v):
11
+ return v is not None
12
+
13
+
14
+ def default(v, d):
15
+ return v if exists(v) else d
16
+
17
+
18
+ # tensor helpers
19
+
20
+
21
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
22
+ if not exists(length):
23
+ length = t.amax()
24
+
25
+ seq = torch.arange(length, device=t.device)
26
+ return seq[None, :] < t[:, None]
27
+
28
+
29
+ # simple utf-8 tokenizer, since paper went character based
30
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
31
+ list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
32
+ text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
33
+ return text
34
+
35
+
36
+ # char tokenizer, based on custom dataset's extracted .txt file
37
+ def list_str_to_idx(
38
+ text: list[str] | list[list[str]],
39
+ vocab_char_map: dict[str, int], # {char: idx}
40
+ padding_value=-1,
41
+ ) -> int["b nt"]: # noqa: F722
42
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
43
+ text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
44
+ return text
45
+
46
+
47
+ # Get tokenizer
48
+
49
+
50
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
51
+ """
52
+ tokenizer - "pinyin" does g2p for Chinese characters only, needs a .txt vocab_file
53
+ - "char" for char-wise tokenizer, need .txt vocab_file
54
+ - "byte" for utf-8 tokenizer
55
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
56
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
57
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
58
+ - if use "byte", set to 256 (unicode byte range)
59
+ """
60
+ if tokenizer in ["pinyin", "char"]:
61
+ # tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
62
+ tokenizer_path = "/ailab-train/speech/zhengjunjie/huggingface/models/F5-TTS/F5TTS_Base/vocab.txt"
63
+ print(f"Loading {tokenizer} tokenizer from {tokenizer_path}")
64
+ with open(tokenizer_path, "r", encoding="utf-8") as f:
65
+ vocab_char_map = {}
66
+ for i, char in enumerate(f):
67
+ vocab_char_map[char[:-1]] = i
68
+ vocab_size = len(vocab_char_map)
69
+ assert vocab_char_map[" "] == 0, "make sure space is at idx 0 in vocab.txt, because 0 is used for the unknown char"
70
+
71
+ elif tokenizer == "byte":
72
+ vocab_char_map = None
73
+ vocab_size = 256
74
+
75
+ elif tokenizer == "custom":
76
+ with open(dataset_name, "r", encoding="utf-8") as f:
77
+ vocab_char_map = {}
78
+ for i, char in enumerate(f):
79
+ vocab_char_map[char[:-1]] = i
80
+ vocab_size = len(vocab_char_map)
81
+
82
+ return vocab_char_map, vocab_size
83
+
84
+
85
+ # convert char to pinyin
86
+
87
+ jieba.initialize()
88
+ print("Word segmentation module jieba initialized.\n")
89
+
90
+
91
+ def convert_char_to_pinyin(text_list, polyphone=True):
92
+ final_text_list = []
93
+ custom_trans = str.maketrans(
94
+ {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
95
+ ) # add custom trans here, to address oov
96
+
97
+ def is_chinese(c):
98
+ return (
99
+ "\u3100" <= c <= "\u9fff" # common chinese characters
100
+ )
101
+
102
+ for text in text_list:
103
+ char_list = []
104
+ text = text.translate(custom_trans)
105
+ for seg in jieba.cut(text):
106
+ seg_byte_len = len(bytes(seg, "UTF-8"))
107
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
108
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
109
+ char_list.append(" ")
110
+ char_list.extend(seg)
111
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
112
+ seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
113
+ for i, c in enumerate(seg):
114
+ if is_chinese(c):
115
+ char_list.append(" ")
116
+ char_list.append(seg_[i])
117
+ else: # if mixed characters, alphabets and symbols
118
+ for c in seg:
119
+ if ord(c) < 256:
120
+ char_list.extend(c)
121
+ elif is_chinese(c):
122
+ char_list.append(" ")
123
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
124
+ else:
125
+ char_list.append(c)
126
+ final_text_list.append(char_list)
127
+
128
+ return final_text_list
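A small usage sketch of the helpers above (assuming `jieba` and `pypinyin` are installed; the example strings are arbitrary):

```python
import torch

texts = ["Hello there!", "你好世界"]

# mixed Chinese/English text -> per-sample token lists (pinyin for CJK, raw chars otherwise)
pinyin_texts = convert_char_to_pinyin(texts)        # list[list[str]]

# byte-level tokenization needs no vocab file and pads with -1
byte_ids = list_str_to_tensor(texts)                # int tensor of shape [b, nt]

# build a boolean padding mask from per-sample token lengths
lens = torch.tensor([len(t) for t in pinyin_texts])
mask = lens_to_mask(lens)                           # bool tensor of shape [b, n]
```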
src/third_party/BigVGAN/.gitignore ADDED
@@ -0,0 +1,146 @@
1
+ # BigVGAN
2
+ alias_free_activation/cuda/build/
3
+ exp/
4
+ tmp/
5
+
6
+ # Symlinks for bundled LibriTTS filelists
7
+ filelists/LibriTTS/train-clean-100
8
+ filelists/LibriTTS/train-clean-360
9
+ filelists/LibriTTS/train-other-500
10
+ filelists/LibriTTS/dev-clean
11
+ filelists/LibriTTS/dev-other
12
+ filelists/LibriTTS/test-clean
13
+ filelists/LibriTTS/test-other
14
+
15
+ # VSCode configs
16
+ .vscode/
17
+
18
+ # Byte-compiled / optimized / DLL files
19
+ __pycache__/
20
+ *.py[cod]
21
+ *$py.class
22
+
23
+ # C extensions
24
+ *.so
25
+
26
+ # Distribution / packaging
27
+ .Python
28
+ build/
29
+ develop-eggs/
30
+ dist/
31
+ downloads/
32
+ eggs/
33
+ .eggs/
34
+ lib/
35
+ lib64/
36
+ parts/
37
+ sdist/
38
+ var/
39
+ wheels/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .nox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ *.py,cover
67
+ .hypothesis/
68
+ .pytest_cache/
69
+ cover/
70
+
71
+ # Translations
72
+ *.mo
73
+ *.pot
74
+
75
+ # Django stuff:
76
+ *.log
77
+ local_settings.py
78
+ db.sqlite3
79
+ db.sqlite3-journal
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ .pybuilder/
93
+ target/
94
+
95
+ # Jupyter Notebook
96
+ .ipynb_checkpoints
97
+
98
+ # IPython
99
+ profile_default/
100
+ ipython_config.py
101
+
102
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
103
+ __pypackages__/
104
+
105
+ # Celery stuff
106
+ celerybeat-schedule
107
+ celerybeat.pid
108
+
109
+ # SageMath parsed files
110
+ *.sage.py
111
+
112
+ # Environments
113
+ .env
114
+ .venv
115
+ env/
116
+ venv/
117
+ ENV/
118
+ env.bak/
119
+ venv.bak/
120
+
121
+ # Spyder project settings
122
+ .spyderproject
123
+ .spyproject
124
+
125
+ # Rope project settings
126
+ .ropeproject
127
+
128
+ # mkdocs documentation
129
+ /site
130
+
131
+ # mypy
132
+ .mypy_cache/
133
+ .dmypy.json
134
+ dmypy.json
135
+
136
+ # Pyre type checker
137
+ .pyre/
138
+
139
+ # pytype static type analyzer
140
+ .pytype/
141
+
142
+ # Cython debug symbols
143
+ cython_debug/
144
+
145
+ # PyCharm
146
+ .idea/
src/third_party/BigVGAN/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/third_party/BigVGAN/README.md ADDED
@@ -0,0 +1,266 @@
1
+ ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
2
+
3
+ #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
4
+
5
+ [[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
6
+
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
8
+
9
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
10
+
11
+ ## News
12
+ - **Sep 2024 (v2.4):**
13
+ - We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
14
+
15
+ - **Jul 2024 (v2.3):**
16
+ - General refactor and code improvements for improved readability.
17
+ - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with an inference speed benchmark.
18
+
19
+ - **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
20
+
21
+ - **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
22
+
23
+ - **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
24
+ - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
25
+ - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
26
+ - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
27
+ - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
28
+
29
+ ## Installation
30
+
31
+ The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
32
+
33
+ ```shell
34
+ conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
35
+ conda activate bigvgan
36
+ ```
37
+
38
+ Clone the repository and install dependencies:
39
+
40
+ ```shell
41
+ git clone https://github.com/NVIDIA/BigVGAN
42
+ cd BigVGAN
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ ## Inference Quickstart using 🤗 Hugging Face Hub
47
+
48
+ The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.
49
+
50
+ ```python
51
+ device = 'cuda'
52
+
53
+ import torch
54
+ import bigvgan
55
+ import librosa
56
+ from meldataset import get_mel_spectrogram
57
+
58
+ # instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
59
+ model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
60
+
61
+ # remove weight norm in the model and set to eval mode
62
+ model.remove_weight_norm()
63
+ model = model.eval().to(device)
64
+
65
+ # load wav file and compute mel spectrogram
66
+ wav_path = '/path/to/your/audio.wav'
67
+ wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
68
+ wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
69
+
70
+ # compute mel spectrogram from the ground truth audio
71
+ mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
72
+
73
+ # generate waveform from mel
74
+ with torch.inference_mode():
75
+ wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
76
+ wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
77
+
78
+ # you can convert the generated waveform to 16 bit linear PCM
79
+ wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
80
+ ```
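If you want to save the result to disk, one possible follow-up (assuming `scipy` is available; the output filename is just an example) is:

```python
from scipy.io import wavfile

# wav_gen_int16 has shape [1, T_time]; wavfile.write expects [T_time] (or [T_time, channels])
wavfile.write('bigvgan_output.wav', model.h.sampling_rate, wav_gen_int16.squeeze(0))
```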
81
+
82
+ ## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
83
+
84
+ You can run a local gradio demo using the commands below:
85
+
86
+ ```shell
87
+ pip install -r demo/requirements.txt
88
+ python demo/app.py
89
+ ```
90
+
91
+ ## Training
92
+
93
+ Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
94
+
95
+ ```shell
96
+ cd filelists/LibriTTS && \
97
+ ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
98
+ ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
99
+ ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
100
+ ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
101
+ ln -s /path/to/your/LibriTTS/dev-other dev-other && \
102
+ ln -s /path/to/your/LibriTTS/test-clean test-clean && \
103
+ ln -s /path/to/your/LibriTTS/test-other test-other && \
104
+ cd ../..
105
+ ```
106
+
107
+ Train the BigVGAN model. Below is an example command for training BigVGAN-v2 using the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
108
+
109
+ ```shell
110
+ python train.py \
111
+ --config configs/bigvgan_v2_24khz_100band_256x.json \
112
+ --input_wavs_dir filelists/LibriTTS \
113
+ --input_training_file filelists/LibriTTS/train-full.txt \
114
+ --input_validation_file filelists/LibriTTS/val-full.txt \
115
+ --list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
116
+ --list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
117
+ --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
118
+ ```
119
+
120
+ ## Synthesis
121
+
122
+ Synthesize from the BigVGAN model. Below is an example command for generating audio from the model.
123
+ It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
124
+
125
+ ```shell
126
+ python inference.py \
127
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
128
+ --input_wavs_dir /path/to/your/input_wav \
129
+ --output_dir /path/to/your/output_wav
130
+ ```
131
+
132
+ `inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
133
+ It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
134
+
135
+ Make sure that the STFT hyperparameters for the mel spectrogram match those of the model, as defined in the `config.json` of the corresponding model.
136
+
137
+ ```shell
138
+ python inference_e2e.py \
139
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
140
+ --input_mels_dir /path/to/your/input_mel \
141
+ --output_dir /path/to/your/output_wav
142
+ ```
143
+
144
+ ## Using Custom CUDA Kernel for Synthesis
145
+
146
+ You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
147
+
148
+ ```python
149
+ generator = BigVGAN(h, use_cuda_kernel=True)
150
+ ```
151
+
152
+ You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
153
+
154
+ When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
155
+
156
+ Please make sure that both are installed on your system and that the installed `nvcc` matches the CUDA version your PyTorch build is using.
157
+
158
+ We recommend running `test_cuda_vs_torch_model.py` first to build the kernel and check its correctness. See the example command below and its output, which reports `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
159
+
160
+ ```shell
161
+ python tests/test_cuda_vs_torch_model.py \
162
+ --checkpoint_file /path/to/your/bigvgan_generator.pt
163
+ ```
164
+
165
+ ```shell
166
+ loading plain Pytorch BigVGAN
167
+ ...
168
+ loading CUDA kernel BigVGAN with auto-build
169
+ Detected CUDA files, patching ldflags
170
+ Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
171
+ Building extension module anti_alias_activation_cuda...
172
+ ...
173
+ Loading extension module anti_alias_activation_cuda...
174
+ ...
175
+ Loading '/path/to/your/bigvgan_generator.pt'
176
+ ...
177
+ [Success] test CUDA fused vs. plain torch BigVGAN inference
178
+ > mean_difference=0.0007238413265440613
179
+ ...
180
+ ```
181
+
182
+ If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
183
+
184
+ ## Pretrained Models
185
+
186
+ We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
187
+ One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
188
+
189
+ | Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
190
+ |:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
191
+ | [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
192
+ | [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
193
+ | [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
194
+ | [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
195
+ | [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
196
+ | [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
197
+ | [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
198
+ | [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
199
+ | [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
200
+
201
+ The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
202
+ We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
203
+ Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
204
+
205
+ You can fine-tune the models by:
206
+
207
+ 1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
208
+ 2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
209
+
210
+ ## Training Details of BigVGAN-v2
211
+
212
+ Compared to the original BigVGAN, the pretrained BigVGAN-v2 checkpoints used `batch_size=32` with a longer `segment_size=65536` and were trained on 8 A100 GPUs.
213
+
214
+ Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` by default to fit on a single A100 GPU for training. You can adjust `batch_size` to match your GPUs when fine-tuning.
215
+
216
+ When training BigVGAN-v2 from scratch with a small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In that case, we recommend lowering the `clip_grad_norm` value (e.g., `100`) for the early training iterations (e.g., the first 20000 steps) and then increasing it back to the default `500`.
217
+
218
+ ## Evaluation Results of BigVGAN-v2
219
+
220
+ Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements across the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
221
+
222
+ | Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
223
+ |:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
224
+ | BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
225
+ | BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
226
+ | BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
227
+ | BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
228
+
229
+ ## Speed Benchmark
230
+
231
+ Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
232
+
233
+ | GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
234
+ |:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
235
+ | NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
236
+ | | | True | 3916.5 | 163.2x | 1.3 |
237
+ | | 2048 | False | 1899.6 | 79.2x | 1.7 |
238
+ | | | True | 5330.1 | 222.1x | 1.7 |
239
+ | | 16384 | False | 1973.8 | 82.2x | 5.0 |
240
+ | | | True | 5761.7 | 240.1x | 4.4 |
241
+ | NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
242
+ | | | True | 1598.1 | 66.6x | 1.3 |
243
+ | | 2048 | False | 929.9 | 38.7x | 1.7 |
244
+ | | | True | 1971.3 | 82.1x | 1.6 |
245
+ | | 16384 | False | 943.4 | 39.3x | 5.0 |
246
+ | | | True | 2026.5 | 84.4x | 3.9 |
247
+ | NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
248
+ | | | True | 811.3 | 33.8x | 1.3 |
249
+ | | 2048 | False | 576.5 | 24.0x | 1.7 |
250
+ | | | True | 1023.0 | 42.6x | 1.5 |
251
+ | | 16384 | False | 589.4 | 24.6x | 5.0 |
252
+ | | | True | 1068.1 | 44.5x | 3.2 |
253
+
254
+ ## Acknowledgements
255
+
256
+ We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
257
+
258
+ ## References
259
+
260
+ - [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
261
+ - [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
262
+ - [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
263
+ - [Julius](https://github.com/adefossez/julius) (for low-pass filter)
264
+ - [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
265
+ - [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
266
+ - [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
src/third_party/BigVGAN/activations.py ADDED
@@ -0,0 +1,126 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ """
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = Snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ """
25
+
26
+ def __init__(
27
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
28
+ ):
29
+ """
30
+ Initialization.
31
+ INPUT:
32
+ - in_features: shape of the input
33
+ - alpha: trainable parameter
34
+ alpha is initialized to 1 by default, higher values = higher-frequency.
35
+ alpha will be trained along with the rest of your model.
36
+ """
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # Initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # Linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass of the function.
54
+ Applies the function to the input elementwise.
55
+ Snake ∶= x + 1/a * sin^2 (xa)
56
+ """
57
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
58
+ if self.alpha_logscale:
59
+ alpha = torch.exp(alpha)
60
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
61
+
62
+ return x
63
+
64
+
65
+ class SnakeBeta(nn.Module):
66
+ """
67
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
68
+ Shape:
69
+ - Input: (B, C, T)
70
+ - Output: (B, C, T), same shape as the input
71
+ Parameters:
72
+ - alpha - trainable parameter that controls frequency
73
+ - beta - trainable parameter that controls magnitude
74
+ References:
75
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
76
+ https://arxiv.org/abs/2006.08195
77
+ Examples:
78
+ >>> a1 = SnakeBeta(256)
79
+ >>> x = torch.randn(256)
80
+ >>> x = a1(x)
81
+ """
82
+
83
+ def __init__(
84
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
85
+ ):
86
+ """
87
+ Initialization.
88
+ INPUT:
89
+ - in_features: shape of the input
90
+ - alpha - trainable parameter that controls frequency
91
+ - beta - trainable parameter that controls magnitude
92
+ alpha is initialized to 1 by default, higher values = higher-frequency.
93
+ beta is initialized to 1 by default, higher values = higher-magnitude.
94
+ alpha will be trained along with the rest of your model.
95
+ """
96
+ super(SnakeBeta, self).__init__()
97
+ self.in_features = in_features
98
+
99
+ # Initialize alpha
100
+ self.alpha_logscale = alpha_logscale
101
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
102
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
103
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
104
+ else: # Linear scale alphas initialized to ones
105
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
106
+ self.beta = Parameter(torch.ones(in_features) * alpha)
107
+
108
+ self.alpha.requires_grad = alpha_trainable
109
+ self.beta.requires_grad = alpha_trainable
110
+
111
+ self.no_div_by_zero = 0.000000001
112
+
113
+ def forward(self, x):
114
+ """
115
+ Forward pass of the function.
116
+ Applies the function to the input elementwise.
117
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
118
+ """
119
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
120
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
121
+ if self.alpha_logscale:
122
+ alpha = torch.exp(alpha)
123
+ beta = torch.exp(beta)
124
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
125
+
126
+ return x
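As a quick sanity check, both activations above can be applied directly to a `[B, C, T]` feature map; the shapes below are arbitrary examples:

```python
import torch

channels = 80                        # e.g. number of mel bands / feature channels
act = Snake(channels)                # linear-scale alpha
act_beta = SnakeBeta(channels, alpha_logscale=True)  # log-scale alpha/beta, as used by the BigVGAN-v2 checkpoints

x = torch.randn(4, channels, 1024)   # [B, C, T]
y = act(x)
z = act_beta(x)
print(y.shape, z.shape)              # both torch.Size([4, 80, 1024])
```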
src/third_party/BigVGAN/alias_free_activation/cuda/__init__.py ADDED
File without changes
src/third_party/BigVGAN/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from alias_free_activation.cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
+
14
+ class FusedAntiAliasActivation(torch.autograd.Function):
15
+ """
16
+ Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17
+ The hyperparameters are hard-coded in the kernel to maximize speed.
18
+ NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
19
+ """
20
+
21
+ @staticmethod
22
+ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23
+ activation_results = anti_alias_activation_cuda.forward(
24
+ inputs, up_ftr, down_ftr, alpha, beta
25
+ )
26
+
27
+ return activation_results
28
+
29
+ @staticmethod
30
+ def backward(ctx, output_grads):
31
+ raise NotImplementedError
32
+ return output_grads, None, None
33
+
34
+
35
+ class Activation1d(nn.Module):
36
+ def __init__(
37
+ self,
38
+ activation,
39
+ up_ratio: int = 2,
40
+ down_ratio: int = 2,
41
+ up_kernel_size: int = 12,
42
+ down_kernel_size: int = 12,
43
+ fused: bool = True,
44
+ ):
45
+ super().__init__()
46
+ self.up_ratio = up_ratio
47
+ self.down_ratio = down_ratio
48
+ self.act = activation
49
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
50
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
51
+
52
+ self.fused = fused # Whether to use fused CUDA kernel or not
53
+
54
+ def forward(self, x):
55
+ if not self.fused:
56
+ x = self.upsample(x)
57
+ x = self.act(x)
58
+ x = self.downsample(x)
59
+ return x
60
+ else:
61
+ if self.act.__class__.__name__ == "Snake":
62
+ beta = self.act.alpha.data # Snake uses same params for alpha and beta
63
+ else:
64
+ beta = (
65
+ self.act.beta.data
66
+ ) # Snakebeta uses different params for alpha and beta
67
+ alpha = self.act.alpha.data
68
+ if (
69
+ not self.act.alpha_logscale
70
+ ): # Exp baked into cuda kernel, cancel it out with a log
71
+ alpha = torch.log(alpha)
72
+ beta = torch.log(beta)
73
+
74
+ x = FusedAntiAliasActivation.apply(
75
+ x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
76
+ )
77
+ return x
src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <torch/extension.h>
18
+
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
src/third_party/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activations. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximize GPU utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
src/third_party/BigVGAN/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /* This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #ifndef TORCH_CHECK
22
+ #define TORCH_CHECK AT_CHECK
23
+ #endif
24
+
25
+ #ifdef VERSION_GE_1_3
26
+ #define DATA_PTR data_ptr
27
+ #else
28
+ #define DATA_PTR data
29
+ #endif
src/third_party/BigVGAN/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ """
11
+ Setting this param to a list has a problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of fused kernels.
12
+ Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below
13
+ """
14
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
+
16
+
17
+ def load():
18
+ # Check if cuda 11 is installed for compute capability 8.0
19
+ cc_flag = []
20
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
21
+ if int(bare_metal_major) >= 11:
22
+ cc_flag.append("-gencode")
23
+ cc_flag.append("arch=compute_80,code=sm_80")
24
+
25
+ # Build path
26
+ srcpath = pathlib.Path(__file__).parent.absolute()
27
+ buildpath = srcpath / "build"
28
+ _create_build_dir(buildpath)
29
+
30
+ # Helper function to build the kernels.
31
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
32
+ return cpp_extension.load(
33
+ name=name,
34
+ sources=sources,
35
+ build_directory=buildpath,
36
+ extra_cflags=[
37
+ "-O3",
38
+ ],
39
+ extra_cuda_cflags=[
40
+ "-O3",
41
+ "-gencode",
42
+ "arch=compute_70,code=sm_70",
43
+ "--use_fast_math",
44
+ ]
45
+ + extra_cuda_flags
46
+ + cc_flag,
47
+ verbose=True,
48
+ )
49
+
50
+ extra_cuda_flags = [
51
+ "-U__CUDA_NO_HALF_OPERATORS__",
52
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
53
+ "--expt-relaxed-constexpr",
54
+ "--expt-extended-lambda",
55
+ ]
56
+
57
+ sources = [
58
+ srcpath / "anti_alias_activation.cpp",
59
+ srcpath / "anti_alias_activation_cuda.cu",
60
+ ]
61
+ anti_alias_activation_cuda = _cpp_extention_load_helper(
62
+ "anti_alias_activation_cuda", sources, extra_cuda_flags
63
+ )
64
+
65
+ return anti_alias_activation_cuda
66
+
67
+
68
+ def _get_cuda_bare_metal_version(cuda_dir):
69
+ raw_output = subprocess.check_output(
70
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71
+ )
72
+ output = raw_output.split()
73
+ release_idx = output.index("release") + 1
74
+ release = output[release_idx].split(".")
75
+ bare_metal_major = release[0]
76
+ bare_metal_minor = release[1][0]
77
+
78
+ return raw_output, bare_metal_major, bare_metal_minor
79
+
80
+
81
+ def _create_build_dir(buildpath):
82
+ try:
83
+ os.mkdir(buildpath)
84
+ except OSError:
85
+ if not os.path.isdir(buildpath):
86
+ print(f"Creation of the build directory {buildpath} failed")
src/third_party/BigVGAN/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
src/third_party/BigVGAN/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
src/third_party/BigVGAN/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ ):
17
+ super().__init__()
18
+ self.up_ratio = up_ratio
19
+ self.down_ratio = down_ratio
20
+ self.act = activation
21
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
22
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
23
+
24
+ # x: [B,C,T]
25
+ def forward(self, x):
26
+ x = self.upsample(x)
27
+ x = self.act(x)
28
+ x = self.downsample(x)
29
+
30
+ return x
src/third_party/BigVGAN/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ """
57
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58
+ """
59
+ filter_ /= filter_.sum()
60
+ filter = filter_.view(1, 1, kernel_size)
61
+
62
+ return filter
63
+
64
+
65
+ class LowPassFilter1d(nn.Module):
66
+ def __init__(
67
+ self,
68
+ cutoff=0.5,
69
+ half_width=0.6,
70
+ stride: int = 1,
71
+ padding: bool = True,
72
+ padding_mode: str = "replicate",
73
+ kernel_size: int = 12,
74
+ ):
75
+ """
76
+ kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
77
+ """
78
+ super().__init__()
79
+ if cutoff < -0.0:
80
+ raise ValueError("Minimum cutoff must be larger than zero.")
81
+ if cutoff > 0.5:
82
+ raise ValueError("A cutoff above 0.5 does not make sense.")
83
+ self.kernel_size = kernel_size
84
+ self.even = kernel_size % 2 == 0
85
+ self.pad_left = kernel_size // 2 - int(self.even)
86
+ self.pad_right = kernel_size // 2
87
+ self.stride = stride
88
+ self.padding = padding
89
+ self.padding_mode = padding_mode
90
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91
+ self.register_buffer("filter", filter)
92
+
93
+ # Input [B, C, T]
94
+ def forward(self, x):
95
+ _, C, _ = x.shape
96
+
97
+ if self.padding:
98
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
100
+
101
+ return out
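A hedged sketch of the two building blocks above; cutoff and half_width are in cycles per sample (Nyquist = 0.5), and shapes follow the [B, C, T] convention used in the code:

    import torch
    from alias_free_activation.torch.filter import kaiser_sinc_filter1d, LowPassFilter1d

    # 12-tap Kaiser-windowed sinc kernel at half of Nyquist; shape [1, 1, 12], sum normalized to ~1.
    kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)

    # Strided low-pass filtering, i.e. anti-aliased 2x decimation of a [B, C, T] signal.
    lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
    y = lpf(torch.randn(2, 8, 1024))  # -> [2, 8, 512]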
src/third_party/BigVGAN/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from alias_free_activation.torch.filter import LowPassFilter1d
+ from alias_free_activation.torch.filter import kaiser_sinc_filter1d
+
+
+ class UpSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.stride = ratio
+         self.pad = self.kernel_size // ratio - 1
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = (
+             self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         )
+         filter = kaiser_sinc_filter1d(
+             cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+         )
+         self.register_buffer("filter", filter)
+
+     # x: [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+
+         x = F.pad(x, (self.pad, self.pad), mode="replicate")
+         x = self.ratio * F.conv_transpose1d(
+             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+         )
+         x = x[..., self.pad_left : -self.pad_right]
+
+         return x
+
+
+ class DownSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.lowpass = LowPassFilter1d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=self.kernel_size,
+         )
+
+     def forward(self, x):
+         xx = self.lowpass(x)
+
+         return xx
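A short sketch of the resamplers defined above (ratio, channels, and length chosen for illustration); with the default kernel size, the 2x up/down pair is an approximate round trip:

    import torch
    from alias_free_activation.torch.resample import UpSample1d, DownSample1d

    x = torch.randn(1, 4, 256)   # [B, C, T]
    up = UpSample1d(ratio=2)     # default kernel_size = int(6 * ratio // 2) * 2 = 12
    down = DownSample1d(ratio=2)

    x_up = up(x)                 # -> [1, 4, 512]
    x_rt = down(x_up)            # -> [1, 4, 256], approximately reconstructs x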
src/third_party/BigVGAN/bigvgan.py ADDED
@@ -0,0 +1,493 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import os
+ import json
+ from pathlib import Path
+ from typing import Optional, Union, Dict
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d
+ from torch.nn.utils import weight_norm, remove_weight_norm
+
+ import activations
+ from utils import init_weights, get_padding
+ from alias_free_activation.torch.act import Activation1d as TorchActivation1d
+ from env import AttrDict
+
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+
+ def load_hparams_from_json(path) -> AttrDict:
+     with open(path) as f:
+         data = f.read()
+     return AttrDict(json.loads(data))
+
+
+ class AMPBlock1(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
+
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+
+         self.h = h
+
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 )
+                 for _ in range(len(dilation))
+             ]
+         )
+         self.convs2.apply(init_weights)
+
+         self.num_layers = len(self.convs1) + len(
+             self.convs2
+         )  # Total number of conv layers
+
+         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=activations.Snake(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=activations.SnakeBeta(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+     def forward(self, x):
+         acts1, acts2 = self.activations[::2], self.activations[1::2]
+         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+             xt = a1(x)
+             xt = c1(xt)
+             xt = a2(xt)
+             xt = c2(xt)
+             x = xt + x
+
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class AMPBlock2(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
+
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+
+         self.h = h
+
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs.apply(init_weights)
+
+         self.num_layers = len(self.convs)  # Total number of conv layers
+
+         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=activations.Snake(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=activations.SnakeBeta(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+     def forward(self, x):
+         for c, a in zip(self.convs, self.activations):
+             xt = a(x)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class BigVGAN(
+     torch.nn.Module,
+     PyTorchModelHubMixin,
+     library_name="bigvgan",
+     repo_url="https://github.com/NVIDIA/BigVGAN",
+     docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
+     pipeline_tag="audio-to-audio",
+     license="mit",
+     tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
+ ):
+     """
+     BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
+     New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
+
+     Args:
+         h (AttrDict): Hyperparameters.
+         use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
+
+     Note:
+         - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
+         - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
+     """
+
+     def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
+         super().__init__()
+         self.h = h
+         self.h["use_cuda_kernel"] = use_cuda_kernel
+
+         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+
+         # Pre-conv
+         self.conv_pre = weight_norm(
+             Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+         )
+
+         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+         if h.resblock == "1":
+             resblock_class = AMPBlock1
+         elif h.resblock == "2":
+             resblock_class = AMPBlock2
+         else:
+             raise ValueError(
+                 f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
+             )
+
+         # Transposed conv-based upsamplers. does not apply anti-aliasing
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         weight_norm(
+                             ConvTranspose1d(
+                                 h.upsample_initial_channel // (2**i),
+                                 h.upsample_initial_channel // (2 ** (i + 1)),
+                                 k,
+                                 u,
+                                 padding=(k - u) // 2,
+                             )
+                         )
+                     ]
+                 )
+             )
+
+         # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(
+                 zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+             ):
+                 self.resblocks.append(
+                     resblock_class(h, ch, k, d, activation=h.activation)
+                 )
+
+         # Post-conv
+         activation_post = (
+             activations.Snake(ch, alpha_logscale=h.snake_logscale)
+             if h.activation == "snake"
+             else (
+                 activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+                 if h.activation == "snakebeta"
+                 else None
+             )
+         )
+         if activation_post is None:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+         self.activation_post = Activation1d(activation=activation_post)
+
+         # Whether to use bias for the final conv_post. Default to True for backward compatibility
+         self.use_bias_at_final = h.get("use_bias_at_final", True)
+         self.conv_post = weight_norm(
+             Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
+         )
+
+         # Weight initialization
+         for i in range(len(self.ups)):
+             self.ups[i].apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+         # Final tanh activation. Defaults to True for backward compatibility
+         self.use_tanh_at_final = h.get("use_tanh_at_final", True)
+
+     def forward(self, x):
+         # Pre-conv
+         x = self.conv_pre(x)
+
+         for i in range(self.num_upsamples):
+             # Upsampling
+             for i_up in range(len(self.ups[i])):
+                 x = self.ups[i][i_up](x)
+             # AMP blocks
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+
+         # Post-conv
+         x = self.activation_post(x)
+         x = self.conv_post(x)
+         # Final tanh activation
+         if self.use_tanh_at_final:
+             x = torch.tanh(x)
+         else:
+             x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]
+
+         return x
+
+     def remove_weight_norm(self):
+         try:
+             print("Removing weight norm...")
+             for l in self.ups:
+                 for l_i in l:
+                     remove_weight_norm(l_i)
+             for l in self.resblocks:
+                 l.remove_weight_norm()
+             remove_weight_norm(self.conv_pre)
+             remove_weight_norm(self.conv_post)
+         except ValueError:
+             print("[INFO] Model already removed weight norm. Skipping!")
+             pass
+
+     # Additional methods for huggingface_hub support
+     def _save_pretrained(self, save_directory: Path) -> None:
+         """Save weights and config.json from a Pytorch model to a local directory."""
+
+         model_path = save_directory / "bigvgan_generator.pt"
+         torch.save({"generator": self.state_dict()}, model_path)
+
+         config_path = save_directory / "config.json"
+         with open(config_path, "w") as config_file:
+             json.dump(self.h, config_file, indent=4)
+
+     @classmethod
+     def _from_pretrained(
+         cls,
+         *,
+         model_id: str,
+         revision: str,
+         cache_dir: str,
+         force_download: bool,
+         proxies: Optional[Dict],
+         resume_download: bool,
+         local_files_only: bool,
+         token: Union[str, bool, None],
+         map_location: str = "cpu",  # Additional argument
+         strict: bool = False,  # Additional argument
+         use_cuda_kernel: bool = False,
+         **model_kwargs,
+     ):
+         """Load Pytorch pretrained weights and return the loaded model."""
+
+         # Download and load hyperparameters (h) used by BigVGAN
+         if os.path.isdir(model_id):
+             print("Loading config.json from local directory")
+             config_file = os.path.join(model_id, "config.json")
+         else:
+             config_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="config.json",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+         h = load_hparams_from_json(config_file)
+
+         # instantiate BigVGAN using h
+         if use_cuda_kernel:
+             print(
+                 f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
+             )
+             print(
+                 f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
+             )
+             print(
+                 f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
+             )
+         model = cls(h, use_cuda_kernel=use_cuda_kernel)
+
+         # Download and load pretrained generator weight
+         if os.path.isdir(model_id):
+             print("Loading weights from local directory")
+             model_file = os.path.join(model_id, "bigvgan_generator.pt")
+         else:
+             print(f"Loading weights from {model_id}")
+             model_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="bigvgan_generator.pt",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+
+         checkpoint_dict = torch.load(model_file, map_location=map_location)
+
+         try:
+             model.load_state_dict(checkpoint_dict["generator"])
+         except RuntimeError:
+             print(
+                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+             )
+             model.remove_weight_norm()
+             model.load_state_dict(checkpoint_dict["generator"])
+
+         return model
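A hedged inference sketch for the class above; from_pretrained is provided by PyTorchModelHubMixin, and the checkpoint id below follows NVIDIA's naming on the Hugging Face Hub (swap in whichever generator this project actually ships with):

    import torch
    from bigvgan import BigVGAN

    model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
    model.remove_weight_norm()  # fold weight norm for inference
    model.eval()

    mel = torch.randn(1, model.h.num_mels, 200)  # [B, num_mels, frames]
    with torch.inference_mode():
        wav = model(mel)  # [B, 1, frames * prod(upsample_rates)]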
src/third_party/BigVGAN/configs/bigvgan_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 32,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "segment_size": 8192,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
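These JSON configs are read into an AttrDict via load_hparams_from_json (defined in bigvgan.py above); a minimal sketch, assuming the path is resolved relative to src/third_party/BigVGAN:

    from bigvgan import BigVGAN, load_hparams_from_json

    h = load_hparams_from_json("configs/bigvgan_22khz_80band.json")
    model = BigVGAN(h, use_cuda_kernel=False)  # randomly initialized 22 kHz / 80-band generator
    print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")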
src/third_party/BigVGAN/configs/bigvgan_24khz_100band.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 32,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "segment_size": 8192,
+     "num_mels": 100,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 24000,
+
+     "fmin": 0,
+     "fmax": 12000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_base_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 32,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [8,8,2,2],
+     "upsample_kernel_sizes": [16,16,4,4],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "segment_size": 8192,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_base_24khz_100band.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 32,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [8,8,2,2],
+     "upsample_kernel_sizes": [16,16,4,4],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "segment_size": 8192,
+     "num_mels": 100,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 24000,
+
+     "fmin": 0,
+     "fmax": 12000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json ADDED
@@ -0,0 +1,61 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 4,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "use_tanh_at_final": false,
+     "use_bias_at_final": false,
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "use_cqtd_instead_of_mrd": true,
+     "cqtd_filters": 128,
+     "cqtd_max_filters": 1024,
+     "cqtd_filters_scale": 1,
+     "cqtd_dilations": [1, 2, 4],
+     "cqtd_hop_lengths": [512, 256, 256],
+     "cqtd_n_octaves": [9, 9, 9],
+     "cqtd_bins_per_octaves": [24, 36, 48],
+
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "use_multiscale_melloss": true,
+     "lambda_melloss": 15,
+
+     "clip_grad_norm": 500,
+
+     "segment_size": 65536,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
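The "256x" in these v2 config names is the total upsampling factor: the product of upsample_rates equals hop_size, so one mel frame maps to 256 waveform samples. A quick check:

    import math
    upsample_rates = [4, 4, 2, 2, 2, 2]
    assert math.prod(upsample_rates) == 256  # matches "hop_size": 256 in this config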
src/third_party/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json ADDED
@@ -0,0 +1,61 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 4,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "use_tanh_at_final": false,
+     "use_bias_at_final": false,
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "use_cqtd_instead_of_mrd": true,
+     "cqtd_filters": 128,
+     "cqtd_max_filters": 1024,
+     "cqtd_filters_scale": 1,
+     "cqtd_dilations": [1, 2, 4],
+     "cqtd_hop_lengths": [512, 256, 256],
+     "cqtd_n_octaves": [9, 9, 9],
+     "cqtd_bins_per_octaves": [24, 36, 48],
+
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "use_multiscale_melloss": true,
+     "lambda_melloss": 15,
+
+     "clip_grad_norm": 500,
+
+     "segment_size": 65536,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json ADDED
@@ -0,0 +1,61 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 4,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "use_tanh_at_final": false,
+     "use_bias_at_final": false,
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "use_cqtd_instead_of_mrd": true,
+     "cqtd_filters": 128,
+     "cqtd_max_filters": 1024,
+     "cqtd_filters_scale": 1,
+     "cqtd_dilations": [1, 2, 4],
+     "cqtd_hop_lengths": [512, 256, 256],
+     "cqtd_n_octaves": [9, 9, 9],
+     "cqtd_bins_per_octaves": [24, 36, 48],
+
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "use_multiscale_melloss": true,
+     "lambda_melloss": 15,
+
+     "clip_grad_norm": 500,
+
+     "segment_size": 65536,
+     "num_mels": 100,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 24000,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json ADDED
@@ -0,0 +1,61 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 4,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [4,4,2,2,2,2],
+     "upsample_kernel_sizes": [8,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "use_tanh_at_final": false,
+     "use_bias_at_final": false,
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "use_cqtd_instead_of_mrd": true,
+     "cqtd_filters": 128,
+     "cqtd_max_filters": 1024,
+     "cqtd_filters_scale": 1,
+     "cqtd_dilations": [1, 2, 4],
+     "cqtd_hop_lengths": [512, 256, 256],
+     "cqtd_n_octaves": [9, 9, 9],
+     "cqtd_bins_per_octaves": [24, 36, 48],
+
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "use_multiscale_melloss": true,
+     "lambda_melloss": 15,
+
+     "clip_grad_norm": 500,
+
+     "segment_size": 65536,
+     "num_mels": 128,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 44100,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
src/third_party/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json ADDED
@@ -0,0 +1,61 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 4,
+     "learning_rate": 0.0001,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.9999996,
+     "seed": 1234,
+
+     "upsample_rates": [8,4,2,2,2,2],
+     "upsample_kernel_sizes": [16,8,4,4,4,4],
+     "upsample_initial_channel": 1536,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "use_tanh_at_final": false,
+     "use_bias_at_final": false,
+
+     "activation": "snakebeta",
+     "snake_logscale": true,
+
+     "use_cqtd_instead_of_mrd": true,
+     "cqtd_filters": 128,
+     "cqtd_max_filters": 1024,
+     "cqtd_filters_scale": 1,
+     "cqtd_dilations": [1, 2, 4],
+     "cqtd_hop_lengths": [512, 256, 256],
+     "cqtd_n_octaves": [9, 9, 9],
+     "cqtd_bins_per_octaves": [24, 36, 48],
+
+     "mpd_reshapes": [2, 3, 5, 7, 11],
+     "use_spectral_norm": false,
+     "discriminator_channel_mult": 1,
+
+     "use_multiscale_melloss": true,
+     "lambda_melloss": 15,
+
+     "clip_grad_norm": 500,
+
+     "segment_size": 65536,
+     "num_mels": 128,
+     "num_freq": 2049,
+     "n_fft": 2048,
+     "hop_size": 512,
+     "win_size": 2048,
+
+     "sampling_rate": 44100,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }