“fred-dev” commited on
Commit
2ed72d6
·
1 Parent(s): 319a18d

Lets build again

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +21 -0
  2. defaults.ini +56 -0
  3. model_config_float_conditioning_dit_all.json +208 -0
  4. pyproject.toml +3 -0
  5. requirements.txt +34 -0
  6. run_gradio.py +31 -0
  7. run_tests.py +44 -0
  8. scripts/ds_zero_to_pl_ckpt.py +14 -0
  9. setup.py +46 -0
  10. stable_audio_tools/__init__.py +2 -0
  11. stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py +4 -0
  12. stable_audio_tools/configs/dataset_configs/local_training_example.json +11 -0
  13. stable_audio_tools/configs/dataset_configs/s3_wds_example.json +10 -0
  14. stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json +71 -0
  15. stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json +88 -0
  16. stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json +111 -0
  17. stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json +122 -0
  18. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json +18 -0
  19. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json +18 -0
  20. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json +18 -0
  21. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json +18 -0
  22. stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json +22 -0
  23. stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json +107 -0
  24. stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json +127 -0
  25. stable_audio_tools/data/__init__.py +0 -0
  26. stable_audio_tools/data/dataset.py +597 -0
  27. stable_audio_tools/data/utils.py +96 -0
  28. stable_audio_tools/inference/__init__.py +0 -0
  29. stable_audio_tools/inference/generation.py +243 -0
  30. stable_audio_tools/inference/sampling.py +170 -0
  31. stable_audio_tools/inference/utils.py +35 -0
  32. stable_audio_tools/interface/__init__.py +0 -0
  33. stable_audio_tools/interface/gradio.py +788 -0
  34. stable_audio_tools/interface/testing.py +409 -0
  35. stable_audio_tools/models/__init__.py +1 -0
  36. stable_audio_tools/models/adp.py +1588 -0
  37. stable_audio_tools/models/autoencoders.py +800 -0
  38. stable_audio_tools/models/blocks.py +339 -0
  39. stable_audio_tools/models/bottleneck.py +326 -0
  40. stable_audio_tools/models/conditioners.py +558 -0
  41. stable_audio_tools/models/diffusion.py +678 -0
  42. stable_audio_tools/models/diffusion_prior.py +151 -0
  43. stable_audio_tools/models/discriminators.py +546 -0
  44. stable_audio_tools/models/dit.py +358 -0
  45. stable_audio_tools/models/factory.py +149 -0
  46. stable_audio_tools/models/lm.py +531 -0
  47. stable_audio_tools/models/lm_backbone.py +157 -0
  48. stable_audio_tools/models/local_attention.py +278 -0
  49. stable_audio_tools/models/musicgen.py +161 -0
  50. stable_audio_tools/models/pqmf.py +393 -0
app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from stable_audio_tools import get_pretrained_model
3
+ from stable_audio_tools.interface.gradio import create_ui
4
+ import json
5
+
6
+ import torch
7
+
8
+ def run():
9
+ torch.manual_seed(42)
10
+
11
+ interface = create_ui(
12
+ model_config_path = "model_config_float_conditioning_dit_all.json",
13
+ ckpt_path="epoch=1292-step=602500.ckpt",
14
+ model_half=False
15
+ )
16
+ interface.queue()
17
+ interface.launch(share=True)
18
+
19
+ if __name__ == "__main__":
20
+ run()
21
+
defaults.ini ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [DEFAULTS]
3
+
4
+ #name of the run
5
+ name = stable_audio_tools
6
+
7
+ # the batch size
8
+ batch_size = 8
9
+
10
+ # number of GPUs to use for training
11
+ num_gpus = 1
12
+
13
+ # number of nodes to use for training
14
+ num_nodes = 1
15
+
16
+ # Multi-GPU strategy for PyTorch Lightning
17
+ strategy = ""
18
+
19
+ # Precision to use for training
20
+ precision = "16-mixed"
21
+
22
+ # number of CPU workers for the DataLoader
23
+ num_workers = 8
24
+
25
+ # the random seed
26
+ seed = 42
27
+
28
+ # Batches for gradient accumulation
29
+ accum_batches = 1
30
+
31
+ # Number of steps between checkpoints
32
+ checkpoint_every = 10000
33
+
34
+ # trainer checkpoint file to restart training from
35
+ ckpt_path = ''
36
+
37
+ # model checkpoint file to start a new training run from
38
+ pretrained_ckpt_path = ''
39
+
40
+ # Checkpoint path for the pretransform model if needed
41
+ pretransform_ckpt_path = ''
42
+
43
+ # configuration model specifying model hyperparameters
44
+ model_config = ''
45
+
46
+ # configuration for datasets
47
+ dataset_config = ''
48
+
49
+ # directory to save the checkpoints in
50
+ save_dir = ''
51
+
52
+ # gradient_clip_val passed into PyTorch Lightning Trainer
53
+ gradient_clip_val = 0.0
54
+
55
+ # remove the weight norm from the pretransform model
56
+ remove_pretransform_weight_norm = ''
model_config_float_conditioning_dit_all.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_cond",
3
+ "sample_size": 1048576,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "dac",
13
+ "config": {
14
+ "in_channels": 1,
15
+ "latent_dim": 128,
16
+ "d_model": 128,
17
+ "strides": [4, 4, 8, 8]
18
+ }
19
+ },
20
+ "decoder": {
21
+ "type": "dac",
22
+ "config": {
23
+ "out_channels": 1,
24
+ "latent_dim": 64,
25
+ "channels": 1536,
26
+ "rates": [8, 8, 4, 4]
27
+ }
28
+ },
29
+ "bottleneck": {
30
+ "type": "vae"
31
+ },
32
+ "latent_dim": 64,
33
+ "downsampling_ratio": 1024,
34
+ "io_channels": 1
35
+ }
36
+ },
37
+ "conditioning": {
38
+ "configs": [
39
+ {
40
+ "id": "latitude",
41
+ "type": "number",
42
+ "config": {
43
+ "min_val": -54.617412,
44
+ "max_val": -10.13994
45
+ }
46
+ },
47
+ {
48
+ "id": "longitude",
49
+ "type": "number",
50
+ "config": {
51
+ "min_val": 96.8233,
52
+ "max_val": 167.9619
53
+ }
54
+ },
55
+ {
56
+ "id": "temperature",
57
+ "type": "number",
58
+ "config": {
59
+ "min_val": -10.0,
60
+ "max_val": 55.0
61
+ }
62
+ },
63
+ {
64
+ "id": "humidity",
65
+ "type": "number",
66
+ "config": {
67
+ "min_val": 1,
68
+ "max_val": 100.0
69
+ }
70
+ },
71
+ {
72
+ "id": "wind_speed",
73
+ "type": "number",
74
+ "config": {
75
+ "min_val": 0,
76
+ "max_val": 50.0
77
+ }
78
+ },
79
+ {
80
+ "id": "pressure",
81
+ "type": "number",
82
+ "config": {
83
+ "min_val": 800.0,
84
+ "max_val": 1200.0
85
+ }
86
+ },
87
+ {
88
+ "id": "minutes_of_day",
89
+ "type": "number",
90
+ "config": {
91
+ "min_val": 0,
92
+ "max_val": 1439
93
+ }
94
+ },
95
+ {
96
+ "id": "day_of_year",
97
+ "type": "number",
98
+ "config": {
99
+ "min_val": 1,
100
+ "max_val": 366
101
+ }
102
+ },
103
+ {
104
+ "id": "seconds_start",
105
+ "type": "number",
106
+ "config": {
107
+ "min_val": 0,
108
+ "max_val": 512
109
+ }
110
+ },
111
+ {
112
+ "id": "seconds_total",
113
+ "type": "number",
114
+ "config": {
115
+ "min_val": 0,
116
+ "max_val": 512
117
+ }
118
+ }
119
+ ],
120
+ "cond_dim": 768
121
+ },
122
+ "diffusion": {
123
+ "cross_attention_cond_ids": ["latitude", "longitude", "temperature", "humidity", "wind_speed", "pressure", "minutes_of_day", "day_of_year","seconds_start", "seconds_total"],
124
+ "global_cond_ids": ["seconds_start", "seconds_total"],
125
+ "type": "dit",
126
+ "config": {
127
+ "io_channels": 64,
128
+ "embed_dim": 768,
129
+ "depth": 24,
130
+ "num_heads": 24,
131
+ "cond_token_dim": 768,
132
+ "global_cond_dim": 1536,
133
+ "project_cond_tokens": false,
134
+ "transformer_type": "continuous_transformer"
135
+ }
136
+ },
137
+ "io_channels": 64
138
+ },
139
+
140
+ "training": {
141
+ "use_ema": true,
142
+ "log_loss_info": false,
143
+ "optimizer_configs": {
144
+ "diffusion": {
145
+ "optimizer": {
146
+ "type": "AdamW",
147
+ "config": {
148
+ "lr": 5e-5,
149
+ "betas": [0.9, 0.999],
150
+ "weight_decay": 1e-3
151
+ }
152
+ },
153
+ "scheduler": {
154
+ "type": "InverseLR",
155
+ "config": {
156
+ "inv_gamma": 1000000,
157
+ "power": 0.5,
158
+ "warmup": 0.99
159
+ }
160
+ }
161
+ }
162
+ },
163
+ "demo": {
164
+ "demo_every": 2500,
165
+ "demo_steps": 100,
166
+ "num_demos": 3,
167
+ "demo_cfg_scales": [3, 5, 7],
168
+ "demo_cond": [
169
+ {
170
+ "latitude": -24.005512,
171
+ "longitude": 133.368348,
172
+ "temperature": 25.5,
173
+ "humidity": 60,
174
+ "wind_speed": 8,
175
+ "pressure": 1000,
176
+ "minutes_of_day": 400,
177
+ "day_of_year": 110,
178
+ "seconds_start": 0,
179
+ "seconds_total": 22
180
+ },
181
+ {
182
+ "latitude": -26.987815,
183
+ "longitude": 153.129068,
184
+ "temperature": 31.5,
185
+ "humidity": 70,
186
+ "wind_speed": 12,
187
+ "pressure": 1010,
188
+ "minutes_of_day": 600,
189
+ "day_of_year": 57,
190
+ "seconds_start": 0,
191
+ "seconds_total": 22
192
+ },
193
+ {
194
+ "latitude": -12.546364,
195
+ "longitude": 130.919605,
196
+ "temperature": 28.5,
197
+ "humidity": 60,
198
+ "wind_speed": 18,
199
+ "pressure": 1015,
200
+ "minutes_of_day": 1140,
201
+ "day_of_year": 280,
202
+ "seconds_start": 0,
203
+ "seconds_total": 22
204
+ }
205
+ ]
206
+ }
207
+ }
208
+ }
pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch
3
+ aeiou
4
+ alias-free-torch
5
+ auraloss
6
+ descript-audio-codec
7
+ einops
8
+ einops-exts
9
+ ema-pytorch
10
+ encodec
11
+ gradio
12
+ huggingface_hub
13
+ importlib-resources
14
+ k-diffusion
15
+ laion-clap
16
+ local-attention
17
+ pandas
18
+ pedalboard
19
+ prefigure
20
+ pytorch_lightning
21
+ PyWavelets
22
+ safetensors
23
+ sentencepiece
24
+ s3fs
25
+ torch
26
+ torchaudio
27
+ torchmetrics
28
+ tqdm
29
+ transformers
30
+ v-diffusion-pytorch
31
+ vector-quantize-pytorch
32
+ wandb
33
+ webdataset
34
+ x_transformers
run_gradio.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_audio_tools import get_pretrained_model
2
+ from stable_audio_tools.interface.gradio import create_ui
3
+ import json
4
+
5
+ import torch
6
+
7
+ def main(args):
8
+ torch.manual_seed(42)
9
+
10
+ interface = create_ui(
11
+ model_config_path = args.model_config,
12
+ ckpt_path=args.ckpt_path,
13
+ pretrained_name=args.pretrained_name,
14
+ pretransform_ckpt_path=args.pretransform_ckpt_path,
15
+ model_half=args.model_half
16
+ )
17
+ interface.queue()
18
+ interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None)
19
+
20
+ if __name__ == "__main__":
21
+ import argparse
22
+ parser = argparse.ArgumentParser(description='Run gradio interface')
23
+ parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
24
+ parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
25
+ parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
26
+ parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
27
+ parser.add_argument('--username', type=str, help='Gradio username', required=False)
28
+ parser.add_argument('--password', type=str, help='Gradio password', required=False)
29
+ parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
30
+ args = parser.parse_args()
31
+ main(args)
run_tests.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_audio_tools import get_pretrained_model
2
+ from stable_audio_tools.interface.testing import runTests
3
+ print(runTests) # Check if it prints a function reference
4
+
5
+
6
+ import torch
7
+
8
+ def main(args):
9
+ torch.manual_seed(42)
10
+ runTests(model_config_path = args.model_config,
11
+ ckpt_path=args.ckpt_path,
12
+ pretrained_name=args.pretrained_name,
13
+ pretransform_ckpt_path=args.pretransform_ckpt_path,
14
+ model_half=args.model_half,
15
+ output_dir=args.output_dir,
16
+ json_dir=args.json_dir
17
+ )
18
+
19
+
20
+
21
+
22
+
23
+ if __name__ == "__main__":
24
+ import argparse
25
+ import sys
26
+ parser = argparse.ArgumentParser(description='Run generation tests')
27
+ parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
28
+ parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
29
+ parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
30
+ parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
31
+ parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
32
+ parser.add_argument('--output-dir', type=str, help='Path to output directory', required=True)
33
+ parser.add_argument('--json-dir', type=str, help='Path to directory containing JSON files', required=True)
34
+ print("Running tests")
35
+
36
+ print("Arguments provided:", sys.argv[1:])
37
+
38
+ args = parser.parse_args()
39
+ print("Parsed arguments:", args)
40
+ main(args)
41
+
42
+
43
+
44
+
scripts/ds_zero_to_pl_ckpt.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
3
+
4
+ if __name__ == "__main__":
5
+
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint")
8
+ parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt")
9
+ args = parser.parse_args()
10
+
11
+ # lightning deepspeed has saved a directory instead of a file
12
+ save_path = args.save_path
13
+ output_path = args.output_path
14
+ convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)
setup.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='stable-audio-tools',
5
+ version='0.0.12',
6
+ url='https://github.com/Stability-AI/stable-audio-tools.git',
7
+ author='Stability AI',
8
+ description='Training and inference tools for generative audio models from Stability AI',
9
+ packages=find_packages(),
10
+ install_requires=[
11
+ 'audiocraft==1.0.0',
12
+ 'aeiou==0.0.20',
13
+ 'alias-free-torch==0.0.6',
14
+ 'auraloss==0.4.0',
15
+ 'descript-audio-codec==1.0.0',
16
+ 'einops==0.7.0',
17
+ 'einops-exts==0.0.4',
18
+ 'ema-pytorch==0.2.3',
19
+ 'encodec==0.1.1',
20
+ 'flash-attn>=2.5.0',
21
+ 'gradio>=3.42.0',
22
+ 'huggingface_hub',
23
+ 'importlib-resources==5.12.0',
24
+ 'k-diffusion==0.1.1',
25
+ 'laion-clap==1.1.4',
26
+ 'local-attention==1.8.6',
27
+ 'pandas==2.0.2',
28
+ 'pedalboard==0.7.4',
29
+ 'prefigure==0.0.9',
30
+ 'pytorch_lightning==2.1.0',
31
+ 'PyWavelets==1.4.1',
32
+ 'safetensors',
33
+ 'sentencepiece==0.1.99',
34
+ 's3fs',
35
+ 'torch>=2.0.1',
36
+ 'torchaudio>=2.0.2',
37
+ 'torchmetrics==0.11.4',
38
+ 'tqdm',
39
+ 'transformers==4.33.3',
40
+ 'v-diffusion-pytorch==0.0.2',
41
+ 'vector-quantize-pytorch==1.9.14',
42
+ 'wandb==0.15.4',
43
+ 'webdataset==0.2.48',
44
+ 'x-transformers<1.27.0'
45
+ ],
46
+ )
stable_audio_tools/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .models.factory import create_model_from_config, create_model_from_config_path
2
+ from .models.pretrained import get_pretrained_model
stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ def get_custom_metadata(info, audio):
2
+
3
+ # Use relative path as the prompt
4
+ return {"prompt": info["relpath"]}
stable_audio_tools/configs/dataset_configs/local_training_example.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_type": "audio_dir",
3
+ "datasets": [
4
+ {
5
+ "id": "my_audio",
6
+ "path": "/path/to/audio/dataset/"
7
+ }
8
+ ],
9
+ "custom_metadata_module": "/path/to/custom_metadata/custom_md_example.py",
10
+ "random_crop": true
11
+ }
stable_audio_tools/configs/dataset_configs/s3_wds_example.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_type": "s3",
3
+ "datasets": [
4
+ {
5
+ "id": "s3-test",
6
+ "s3_path": "s3://my-bucket/datasets/webdataset/audio/"
7
+ }
8
+ ],
9
+ "random_crop": true
10
+ }
stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "dac",
9
+ "config": {
10
+ "latent_dim": 64,
11
+ "d_model": 128,
12
+ "strides": [4, 8, 8, 8]
13
+ }
14
+ },
15
+ "decoder": {
16
+ "type": "dac",
17
+ "config": {
18
+ "latent_dim": 32,
19
+ "channels": 1536,
20
+ "rates": [8, 8, 8, 4]
21
+ }
22
+ },
23
+ "bottleneck": {
24
+ "type": "vae"
25
+ },
26
+ "latent_dim": 32,
27
+ "downsampling_ratio": 2048,
28
+ "io_channels": 1
29
+ },
30
+ "training": {
31
+ "learning_rate": 1e-4,
32
+ "warmup_steps": 0,
33
+ "use_ema": false,
34
+ "loss_configs": {
35
+ "discriminator": {
36
+ "type": "encodec",
37
+ "config": {
38
+ "filters": 32,
39
+ "n_ffts": [2048, 1024, 512, 256, 128, 64, 32],
40
+ "hop_lengths": [512, 256, 128, 64, 32, 16, 8],
41
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32]
42
+ },
43
+ "weights": {
44
+ "adversarial": 0.1,
45
+ "feature_matching": 5.0
46
+ }
47
+ },
48
+ "spectral": {
49
+ "type": "mrstft",
50
+ "config": {
51
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
52
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
53
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
54
+ "perceptual_weighting": true
55
+ },
56
+ "weights": {
57
+ "mrstft": 1.0
58
+ }
59
+ },
60
+ "time": {
61
+ "type": "l1",
62
+ "weights": {
63
+ "l1": 0.0
64
+ }
65
+ }
66
+ },
67
+ "demo": {
68
+ "demo_every": 2000
69
+ }
70
+ }
71
+ }
stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 32000,
4
+ "sample_rate": 32000,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "seanet",
9
+ "config": {
10
+ "channels": 1,
11
+ "dimension": 128,
12
+ "n_filters": 64,
13
+ "ratios": [4, 4, 5, 8],
14
+ "n_residual_layers": 1,
15
+ "dilation_base": 2,
16
+ "lstm": 2,
17
+ "norm": "weight_norm"
18
+ }
19
+ },
20
+ "decoder": {
21
+ "type": "seanet",
22
+ "config": {
23
+ "channels": 1,
24
+ "dimension": 128,
25
+ "n_filters": 64,
26
+ "ratios": [4, 4, 5, 8],
27
+ "n_residual_layers": 1,
28
+ "dilation_base": 2,
29
+ "lstm": 2,
30
+ "norm": "weight_norm"
31
+ }
32
+ },
33
+ "bottleneck": {
34
+ "type": "rvq",
35
+ "config": {
36
+ "num_quantizers": 4,
37
+ "codebook_size": 2048,
38
+ "dim": 128,
39
+ "decay": 0.99,
40
+ "threshold_ema_dead_code": 2
41
+ }
42
+ },
43
+ "latent_dim": 128,
44
+ "downsampling_ratio": 640,
45
+ "io_channels": 1
46
+ },
47
+ "training": {
48
+ "learning_rate": 1e-4,
49
+ "warmup_steps": 0,
50
+ "use_ema": true,
51
+ "loss_configs": {
52
+ "discriminator": {
53
+ "type": "encodec",
54
+ "config": {
55
+ "filters": 32,
56
+ "n_ffts": [2048, 1024, 512, 256, 128],
57
+ "hop_lengths": [512, 256, 128, 64, 32],
58
+ "win_lengths": [2048, 1024, 512, 256, 128]
59
+ },
60
+ "weights": {
61
+ "adversarial": 0.1,
62
+ "feature_matching": 5.0
63
+ }
64
+ },
65
+ "spectral": {
66
+ "type": "mrstft",
67
+ "config": {
68
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
69
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
70
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
71
+ "perceptual_weighting": true
72
+ },
73
+ "weights": {
74
+ "mrstft": 1.0
75
+ }
76
+ },
77
+ "time": {
78
+ "type": "l1",
79
+ "weights": {
80
+ "l1": 0.0
81
+ }
82
+ }
83
+ },
84
+ "demo": {
85
+ "demo_every": 2000
86
+ }
87
+ }
88
+ }
stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "dac",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "latent_dim": 128,
12
+ "d_model": 128,
13
+ "strides": [4, 4, 8, 8]
14
+ }
15
+ },
16
+ "decoder": {
17
+ "type": "dac",
18
+ "config": {
19
+ "out_channels": 2,
20
+ "latent_dim": 64,
21
+ "channels": 1536,
22
+ "rates": [8, 8, 4, 4]
23
+ }
24
+ },
25
+ "bottleneck": {
26
+ "type": "vae"
27
+ },
28
+ "latent_dim": 64,
29
+ "downsampling_ratio": 1024,
30
+ "io_channels": 2
31
+ },
32
+ "training": {
33
+ "learning_rate": 1e-4,
34
+ "warmup_steps": 0,
35
+ "use_ema": true,
36
+ "optimizer_configs": {
37
+ "autoencoder": {
38
+ "optimizer": {
39
+ "type": "AdamW",
40
+ "config": {
41
+ "betas": [0.8, 0.99],
42
+ "lr": 1e-4
43
+ }
44
+ },
45
+ "scheduler": {
46
+ "type": "ExponentialLR",
47
+ "config": {
48
+ "gamma": 0.999996
49
+ }
50
+ }
51
+ },
52
+ "discriminator": {
53
+ "optimizer": {
54
+ "type": "AdamW",
55
+ "config": {
56
+ "betas": [0.8, 0.99],
57
+ "lr": 1e-4
58
+ }
59
+ },
60
+ "scheduler": {
61
+ "type": "ExponentialLR",
62
+ "config": {
63
+ "gamma": 0.999996
64
+ }
65
+ }
66
+ }
67
+ },
68
+ "loss_configs": {
69
+ "discriminator": {
70
+ "type": "encodec",
71
+ "config": {
72
+ "filters": 32,
73
+ "n_ffts": [2048, 1024, 512, 256, 128],
74
+ "hop_lengths": [512, 256, 128, 64, 32],
75
+ "win_lengths": [2048, 1024, 512, 256, 128]
76
+ },
77
+ "weights": {
78
+ "adversarial": 0.1,
79
+ "feature_matching": 5.0
80
+ }
81
+ },
82
+ "spectral": {
83
+ "type": "mrstft",
84
+ "config": {
85
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
86
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
87
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
88
+ "perceptual_weighting": true
89
+ },
90
+ "weights": {
91
+ "mrstft": 1.0
92
+ }
93
+ },
94
+ "time": {
95
+ "type": "l1",
96
+ "weights": {
97
+ "l1": 0.0
98
+ }
99
+ },
100
+ "bottleneck": {
101
+ "type": "kl",
102
+ "weights": {
103
+ "kl": 1e-6
104
+ }
105
+ }
106
+ },
107
+ "demo": {
108
+ "demo_every": 2000
109
+ }
110
+ }
111
+ }
stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8, 16],
13
+ "strides": [2, 4, 4, 8, 8],
14
+ "latent_dim": 128,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 8, 8],
25
+ "latent_dim": 64,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 64,
34
+ "downsampling_ratio": 2048,
35
+ "io_channels": 2
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": true,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [2048, 1024, 512, 256, 128],
85
+ "hop_lengths": [512, 256, 128, 64, 32],
86
+ "win_lengths": [2048, 1024, 512, 256, 128]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
97
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
98
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 2000
120
+ }
121
+ }
122
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 48000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 16000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 4e-5,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 131072,
4
+ "sample_rate": 48000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "musicgen",
3
+ "sample_size": 320000,
4
+ "sample_rate": 32000,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "pretrained": "small"
8
+ },
9
+ "training": {
10
+ "learning_rate": 1e-4,
11
+ "demo": {
12
+ "demo_every": 2000,
13
+ "demo_cond": [
14
+ {"prompt": "Keywords: Atmospheres, Orchestral Drone, Bass, Sci-Fi Ambient Soundscape, Synthesiser, Middle Eastern Vocal, dramatic piano"},
15
+ {"prompt": "Genre: Corporate|Instruments: Ukulele, Drums, Clapping, Glockenspiel"},
16
+ {"prompt": "Description: 116 BPM rock drums, drum track for a rock song"},
17
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle."}
18
+ ],
19
+ "demo_cfg_scales": [3, 6, 9]
20
+ }
21
+ }
22
+ }
stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_cond",
3
+ "sample_size": 4194304,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "dac",
13
+ "config": {
14
+ "in_channels": 2,
15
+ "latent_dim": 128,
16
+ "d_model": 128,
17
+ "strides": [4, 4, 8, 8]
18
+ }
19
+ },
20
+ "decoder": {
21
+ "type": "dac",
22
+ "config": {
23
+ "out_channels": 2,
24
+ "latent_dim": 64,
25
+ "channels": 1536,
26
+ "rates": [8, 8, 4, 4]
27
+ }
28
+ },
29
+ "bottleneck": {
30
+ "type": "vae"
31
+ },
32
+ "latent_dim": 64,
33
+ "downsampling_ratio": 1024,
34
+ "io_channels": 2
35
+ }
36
+ },
37
+ "conditioning": {
38
+ "configs": [
39
+ {
40
+ "id": "prompt",
41
+ "type": "clap_text",
42
+ "config": {
43
+ "audio_model_type": "HTSAT-base",
44
+ "enable_fusion": true,
45
+ "clap_ckpt_path": "/path/to/clap.ckpt",
46
+ "use_text_features": true,
47
+ "feature_layer_ix": -2
48
+ }
49
+ },
50
+ {
51
+ "id": "seconds_start",
52
+ "type": "int",
53
+ "config": {
54
+ "min_val": 0,
55
+ "max_val": 512
56
+ }
57
+ },
58
+ {
59
+ "id": "seconds_total",
60
+ "type": "int",
61
+ "config": {
62
+ "min_val": 0,
63
+ "max_val": 512
64
+ }
65
+ }
66
+ ],
67
+ "cond_dim": 768
68
+ },
69
+ "diffusion": {
70
+ "type": "adp_cfg_1d",
71
+ "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
72
+ "config": {
73
+ "in_channels": 64,
74
+ "context_embedding_features": 768,
75
+ "context_embedding_max_length": 79,
76
+ "channels": 256,
77
+ "resnet_groups": 16,
78
+ "kernel_multiplier_downsample": 2,
79
+ "multipliers": [4, 4, 4, 5, 5],
80
+ "factors": [1, 2, 2, 4],
81
+ "num_blocks": [2, 2, 2, 2],
82
+ "attentions": [1, 3, 3, 3, 3],
83
+ "attention_heads": 16,
84
+ "attention_multiplier": 4,
85
+ "use_nearest_upsample": false,
86
+ "use_skip_scale": true,
87
+ "use_context_time": true
88
+ }
89
+ },
90
+ "io_channels": 64
91
+ },
92
+ "training": {
93
+ "learning_rate": 4e-5,
94
+ "demo": {
95
+ "demo_every": 2000,
96
+ "demo_steps": 250,
97
+ "num_demos": 4,
98
+ "demo_cond": [
99
+ {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95},
100
+ {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90},
101
+ {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180},
102
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60}
103
+ ],
104
+ "demo_cfg_scales": [3, 6, 9]
105
+ }
106
+ }
107
+ }
stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "diffusion_cond",
3
+ "sample_size": 12582912,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "oobleck",
13
+ "config": {
14
+ "in_channels": 2,
15
+ "channels": 128,
16
+ "c_mults": [1, 2, 4, 8, 16],
17
+ "strides": [2, 4, 4, 8, 8],
18
+ "latent_dim": 128,
19
+ "use_snake": true
20
+ }
21
+ },
22
+ "decoder": {
23
+ "type": "oobleck",
24
+ "config": {
25
+ "out_channels": 2,
26
+ "channels": 128,
27
+ "c_mults": [1, 2, 4, 8, 16],
28
+ "strides": [2, 4, 4, 8, 8],
29
+ "latent_dim": 64,
30
+ "use_snake": true,
31
+ "final_tanh": false
32
+ }
33
+ },
34
+ "bottleneck": {
35
+ "type": "vae"
36
+ },
37
+ "latent_dim": 64,
38
+ "downsampling_ratio": 2048,
39
+ "io_channels": 2
40
+ }
41
+ },
42
+ "conditioning": {
43
+ "configs": [
44
+ {
45
+ "id": "prompt",
46
+ "type": "clap_text",
47
+ "config": {
48
+ "audio_model_type": "HTSAT-base",
49
+ "enable_fusion": true,
50
+ "clap_ckpt_path": "/path/to/clap.ckpt",
51
+ "use_text_features": true,
52
+ "feature_layer_ix": -2
53
+ }
54
+ },
55
+ {
56
+ "id": "seconds_start",
57
+ "type": "number",
58
+ "config": {
59
+ "min_val": 0,
60
+ "max_val": 512
61
+ }
62
+ },
63
+ {
64
+ "id": "seconds_total",
65
+ "type": "number",
66
+ "config": {
67
+ "min_val": 0,
68
+ "max_val": 512
69
+ }
70
+ }
71
+ ],
72
+ "cond_dim": 768
73
+ },
74
+ "diffusion": {
75
+ "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
76
+ "global_cond_ids": ["seconds_start", "seconds_total"],
77
+ "type": "dit",
78
+ "config": {
79
+ "io_channels": 64,
80
+ "embed_dim": 1536,
81
+ "depth": 24,
82
+ "num_heads": 24,
83
+ "cond_token_dim": 768,
84
+ "global_cond_dim": 1536,
85
+ "project_cond_tokens": false,
86
+ "transformer_type": "continuous_transformer"
87
+ }
88
+ },
89
+ "io_channels": 64
90
+ },
91
+ "training": {
92
+ "use_ema": true,
93
+ "log_loss_info": false,
94
+ "optimizer_configs": {
95
+ "diffusion": {
96
+ "optimizer": {
97
+ "type": "AdamW",
98
+ "config": {
99
+ "lr": 5e-5,
100
+ "betas": [0.9, 0.999],
101
+ "weight_decay": 1e-3
102
+ }
103
+ },
104
+ "scheduler": {
105
+ "type": "InverseLR",
106
+ "config": {
107
+ "inv_gamma": 1000000,
108
+ "power": 0.5,
109
+ "warmup": 0.99
110
+ }
111
+ }
112
+ }
113
+ },
114
+ "demo": {
115
+ "demo_every": 2000,
116
+ "demo_steps": 250,
117
+ "num_demos": 4,
118
+ "demo_cond": [
119
+ {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80},
120
+ {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250},
121
+ {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180},
122
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 190}
123
+ ],
124
+ "demo_cfg_scales": [3, 6, 9]
125
+ }
126
+ }
127
+ }
stable_audio_tools/data/__init__.py ADDED
File without changes
stable_audio_tools/data/dataset.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import numpy as np
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import random
7
+ import re
8
+ import subprocess
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+ import webdataset as wds
13
+
14
+ from aeiou.core import is_silence
15
+ from os import path
16
+ from pedalboard.io import AudioFile
17
+ from torchaudio import transforms as T
18
+ from typing import Optional, Callable, List
19
+
20
+ from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
21
+
22
+ AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
23
+
24
+ # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
25
+
26
+ def fast_scandir(
27
+ dir:str, # top-level directory at which to begin scanning
28
+ ext:list, # list of allowed file extensions,
29
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
30
+ ):
31
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
32
+ subfolders, files = [], []
33
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
34
+ try: # hope to avoid 'permission denied' by this try
35
+ for f in os.scandir(dir):
36
+ try: # 'hope to avoid too many levels of symbolic links' error
37
+ if f.is_dir():
38
+ subfolders.append(f.path)
39
+ elif f.is_file():
40
+ file_ext = os.path.splitext(f.name)[1].lower()
41
+ is_hidden = os.path.basename(f.path).startswith(".")
42
+
43
+ if file_ext in ext and not is_hidden:
44
+ files.append(f.path)
45
+ except:
46
+ pass
47
+ except:
48
+ pass
49
+
50
+ for dir in list(subfolders):
51
+ sf, f = fast_scandir(dir, ext)
52
+ subfolders.extend(sf)
53
+ files.extend(f)
54
+ return subfolders, files
55
+
56
+ def keyword_scandir(
57
+ dir: str, # top-level directory at which to begin scanning
58
+ ext: list, # list of allowed file extensions
59
+ keywords: list, # list of keywords to search for in the file name
60
+ ):
61
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
62
+ subfolders, files = [], []
63
+ # make keywords case insensitive
64
+ keywords = [keyword.lower() for keyword in keywords]
65
+ # add starting period to extensions if needed
66
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
67
+ banned_words = ["paxheader", "__macosx"]
68
+ try: # hope to avoid 'permission denied' by this try
69
+ for f in os.scandir(dir):
70
+ try: # 'hope to avoid too many levels of symbolic links' error
71
+ if f.is_dir():
72
+ subfolders.append(f.path)
73
+ elif f.is_file():
74
+ is_hidden = f.name.split("/")[-1][0] == '.'
75
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
76
+ name_lower = f.name.lower()
77
+ has_keyword = any(
78
+ [keyword in name_lower for keyword in keywords])
79
+ has_banned = any(
80
+ [banned_word in name_lower for banned_word in banned_words])
81
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
82
+ files.append(f.path)
83
+ except:
84
+ pass
85
+ except:
86
+ pass
87
+
88
+ for dir in list(subfolders):
89
+ sf, f = keyword_scandir(dir, ext, keywords)
90
+ subfolders.extend(sf)
91
+ files.extend(f)
92
+ return subfolders, files
93
+
94
+ def get_audio_filenames(
95
+ paths: list, # directories in which to search
96
+ keywords=None,
97
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
98
+ ):
99
+ "recursively get a list of audio filenames"
100
+ filenames = []
101
+ if type(paths) is str:
102
+ paths = [paths]
103
+ for path in paths: # get a list of relevant filenames
104
+ if keywords is not None:
105
+ subfolders, files = keyword_scandir(path, exts, keywords)
106
+ else:
107
+ subfolders, files = fast_scandir(path, exts)
108
+ filenames.extend(files)
109
+ return filenames
110
+
111
+ class SampleDataset(torch.utils.data.Dataset):
112
+ def __init__(
113
+ self,
114
+ paths,
115
+ sample_size=65536,
116
+ sample_rate=48000,
117
+ keywords=None,
118
+ relpath=None,
119
+ random_crop=True,
120
+ force_channels="stereo",
121
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
122
+ ):
123
+ super().__init__()
124
+ self.filenames = []
125
+ self.relpath = relpath
126
+
127
+ self.augs = torch.nn.Sequential(
128
+ PhaseFlipper(),
129
+ )
130
+
131
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
132
+
133
+ self.force_channels = force_channels
134
+
135
+ self.encoding = torch.nn.Sequential(
136
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
137
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
138
+ )
139
+
140
+ self.filenames = get_audio_filenames(paths, keywords)
141
+
142
+ print(f'Found {len(self.filenames)} files')
143
+
144
+ self.sr = sample_rate
145
+
146
+ self.custom_metadata_fn = custom_metadata_fn
147
+
148
+ def load_file(self, filename):
149
+ ext = filename.split(".")[-1]
150
+
151
+ if ext == "mp3":
152
+ with AudioFile(filename) as f:
153
+ audio = f.read(f.frames)
154
+ audio = torch.from_numpy(audio)
155
+ in_sr = f.samplerate
156
+ else:
157
+ audio, in_sr = torchaudio.load(filename, format=ext)
158
+
159
+ if in_sr != self.sr:
160
+ resample_tf = T.Resample(in_sr, self.sr)
161
+ audio = resample_tf(audio)
162
+
163
+ return audio
164
+
165
+ def __len__(self):
166
+ return len(self.filenames)
167
+
168
+ def __getitem__(self, idx):
169
+ audio_filename = self.filenames[idx]
170
+ try:
171
+ start_time = time.time()
172
+ audio = self.load_file(audio_filename)
173
+
174
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
175
+
176
+ # Run augmentations on this sample (including random crop)
177
+ if self.augs is not None:
178
+ audio = self.augs(audio)
179
+
180
+ audio = audio.clamp(-1, 1)
181
+
182
+ # Encode the file to assist in prediction
183
+ if self.encoding is not None:
184
+ audio = self.encoding(audio)
185
+
186
+ info = {}
187
+
188
+ info["path"] = audio_filename
189
+
190
+ if self.relpath is not None:
191
+ info["relpath"] = path.relpath(audio_filename, self.relpath)
192
+
193
+ info["timestamps"] = (t_start, t_end)
194
+ info["seconds_start"] = seconds_start
195
+ info["seconds_total"] = seconds_total
196
+ info["padding_mask"] = padding_mask
197
+
198
+ end_time = time.time()
199
+
200
+ info["load_time"] = end_time - start_time
201
+
202
+ if self.custom_metadata_fn is not None:
203
+ custom_metadata = self.custom_metadata_fn(info, audio)
204
+ info.update(custom_metadata)
205
+
206
+ if "__reject__" in info and info["__reject__"]:
207
+ return self[random.randrange(len(self))]
208
+
209
+ return (audio, info)
210
+ except Exception as e:
211
+ print(f'Couldn\'t load file {audio_filename}: {e}')
212
+ return self[random.randrange(len(self))]
213
+
214
+ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
215
+ """Return function over iterator that groups key, value pairs into samples.
216
+ :param keys: function that splits the key into key and extension (base_plus_ext)
217
+ :param lcase: convert suffixes to lower case (Default value = True)
218
+ """
219
+ current_sample = None
220
+ for filesample in data:
221
+ assert isinstance(filesample, dict)
222
+ fname, value = filesample["fname"], filesample["data"]
223
+ prefix, suffix = keys(fname)
224
+ if wds.tariterators.trace:
225
+ print(
226
+ prefix,
227
+ suffix,
228
+ current_sample.keys() if isinstance(current_sample, dict) else None,
229
+ )
230
+ if prefix is None:
231
+ continue
232
+ if lcase:
233
+ suffix = suffix.lower()
234
+ if current_sample is None or prefix != current_sample["__key__"]:
235
+ if wds.tariterators.valid_sample(current_sample):
236
+ yield current_sample
237
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
238
+ if suffix in current_sample:
239
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
240
+ if suffixes is None or suffix in suffixes:
241
+ current_sample[suffix] = value
242
+ if wds.tariterators.valid_sample(current_sample):
243
+ yield current_sample
244
+
245
+ wds.tariterators.group_by_keys = group_by_keys
246
+
247
+ # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
248
+
249
+ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
250
+ """
251
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
252
+ """
253
+ # Ensure dataset_path ends with a trailing slash
254
+ if dataset_path != '' and not dataset_path.endswith('/'):
255
+ dataset_path += '/'
256
+ # Use posixpath to construct the S3 URL path
257
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
258
+ # Construct the `aws s3 ls` command
259
+ cmd = ['aws', 's3', 'ls', bucket_path]
260
+
261
+ if profile is not None:
262
+ cmd.extend(['--profile', profile])
263
+
264
+ if recursive:
265
+ # Add the --recursive flag if requested
266
+ cmd.append('--recursive')
267
+
268
+ # Run the `aws s3 ls` command and capture the output
269
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
270
+ # Split the output into lines and strip whitespace from each line
271
+ contents = run_ls.stdout.decode('utf-8').split('\n')
272
+ contents = [x.strip() for x in contents if x]
273
+ # Remove the timestamp from lines that begin with a timestamp
274
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
275
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
276
+ # Construct a full S3 path for each file in the contents list
277
+ contents = [posixpath.join(s3_url_prefix or '', x)
278
+ for x in contents if not x.endswith('/')]
279
+ # Apply the filter, if specified
280
+ if filter:
281
+ contents = [x for x in contents if filter in x]
282
+ # Remove redundant directory names in the S3 URL
283
+ if recursive:
284
+ # Get the main directory name from the S3 URL
285
+ main_dir = "/".join(bucket_path.split('/')[3:])
286
+ # Remove the redundant directory names from each file path
287
+ contents = [x.replace(f'{main_dir}', '').replace(
288
+ '//', '/') for x in contents]
289
+ # Print debugging information, if requested
290
+ if debug:
291
+ print("contents = \n", contents)
292
+ # Return the list of S3 paths to files
293
+ return contents
294
+
295
+
296
+ def get_all_s3_urls(
297
+ names=[], # list of all valid [LAION AudioDataset] dataset names
298
+ # list of subsets you want from those datasets, e.g. ['train','valid']
299
+ subsets=[''],
300
+ s3_url_prefix=None, # prefix for those dataset names
301
+ recursive=True, # recursively list all tar files in all subdirs
302
+ filter_str='tar', # only grab files with this substring
303
+ # print debugging info -- note: info displayed likely to change at dev's whims
304
+ debug=False,
305
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
306
+ ):
307
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
308
+ urls = []
309
+ for name in names:
310
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
311
+ if s3_url_prefix is None:
312
+ contents_str = name
313
+ else:
314
+ # Construct the S3 path using the s3_url_prefix and the current name value
315
+ contents_str = posixpath.join(s3_url_prefix, name)
316
+ if debug:
317
+ print(f"get_all_s3_urls: {contents_str}:")
318
+ for subset in subsets:
319
+ subset_str = posixpath.join(contents_str, subset)
320
+ if debug:
321
+ print(f"subset_str = {subset_str}")
322
+ # Get the list of tar files in the current subset directory
323
+ profile = profiles.get(name, None)
324
+ tar_list = get_s3_contents(
325
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
326
+ for tar in tar_list:
327
+ # Escape spaces and parentheses in the tar filename for use in the shell command
328
+ tar = tar.replace(" ", "\ ").replace(
329
+ "(", "\(").replace(")", "\)")
330
+ # Construct the S3 path to the current tar file
331
+ s3_path = posixpath.join(name, subset, tar) + " -"
332
+ # Construct the AWS CLI command to download the current tar file
333
+ if s3_url_prefix is None:
334
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
335
+ else:
336
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
337
+ if profiles.get(name):
338
+ request_str += f" --profile {profiles.get(name)}"
339
+ if debug:
340
+ print("request_str = ", request_str)
341
+ # Add the constructed URL to the list of URLs
342
+ urls.append(request_str)
343
+ return urls
344
+
345
+
346
+ def log_and_continue(exn):
347
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
348
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
349
+ return True
350
+
351
+
352
+ def is_valid_sample(sample):
353
+ has_json = "json" in sample
354
+ has_audio = "audio" in sample
355
+ is_silent = is_silence(sample["audio"])
356
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
357
+
358
+ return has_json and has_audio and not is_silent and not is_rejected
359
+
360
+ class S3DatasetConfig:
361
+ def __init__(
362
+ self,
363
+ id: str,
364
+ s3_path: str,
365
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
366
+ profile: Optional[str] = None,
367
+ ):
368
+ self.id = id
369
+ self.s3_path = s3_path
370
+ self.custom_metadata_fn = custom_metadata_fn
371
+ self.profile = profile
372
+ self.urls = []
373
+
374
+ def load_data_urls(self):
375
+ self.urls = get_all_s3_urls(
376
+ names=[self.s3_path],
377
+ s3_url_prefix=None,
378
+ recursive=True,
379
+ profiles={self.s3_path: self.profile} if self.profile else {},
380
+ )
381
+
382
+ return self.urls
383
+
384
+ def audio_decoder(key, value):
385
+ # Get file extension from key
386
+ ext = key.split(".")[-1]
387
+
388
+ if ext in AUDIO_KEYS:
389
+ return torchaudio.load(io.BytesIO(value))
390
+ else:
391
+ return None
392
+
393
+ def collation_fn(samples):
394
+ batched = list(zip(*samples))
395
+ result = []
396
+ for b in batched:
397
+ if isinstance(b[0], (int, float)):
398
+ b = np.array(b)
399
+ elif isinstance(b[0], torch.Tensor):
400
+ b = torch.stack(b)
401
+ elif isinstance(b[0], np.ndarray):
402
+ b = np.array(b)
403
+ else:
404
+ b = b
405
+ result.append(b)
406
+ return result
407
+
408
+ class S3WebDataLoader():
409
+ def __init__(
410
+ self,
411
+ datasets: List[S3DatasetConfig],
412
+ batch_size,
413
+ sample_size,
414
+ sample_rate=48000,
415
+ num_workers=8,
416
+ epoch_steps=1000,
417
+ random_crop=True,
418
+ force_channels="stereo",
419
+ augment_phase=True,
420
+ **data_loader_kwargs
421
+ ):
422
+
423
+ self.datasets = datasets
424
+
425
+ self.sample_size = sample_size
426
+ self.sample_rate = sample_rate
427
+ self.random_crop = random_crop
428
+ self.force_channels = force_channels
429
+ self.augment_phase = augment_phase
430
+
431
+ urls = [dataset.load_data_urls() for dataset in datasets]
432
+
433
+ # Flatten the list of lists of URLs
434
+ urls = [url for dataset_urls in urls for url in dataset_urls]
435
+
436
+ self.dataset = wds.DataPipeline(
437
+ wds.ResampledShards(urls),
438
+ wds.tarfile_to_samples(handler=log_and_continue),
439
+ wds.decode(audio_decoder, handler=log_and_continue),
440
+ wds.map(self.wds_preprocess, handler=log_and_continue),
441
+ wds.select(is_valid_sample),
442
+ wds.to_tuple("audio", "json", handler=log_and_continue),
443
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
444
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
445
+
446
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
447
+
448
+ def wds_preprocess(self, sample):
449
+
450
+ found_key, rewrite_key = '', ''
451
+ for k, v in sample.items(): # print the all entries in dict
452
+ for akey in AUDIO_KEYS:
453
+ if k.endswith(akey):
454
+ # to rename long/weird key with its simpler counterpart
455
+ found_key, rewrite_key = k, akey
456
+ break
457
+ if '' != found_key:
458
+ break
459
+ if '' == found_key: # got no audio!
460
+ return None # try returning None to tell WebDataset to skip this one
461
+
462
+ audio, in_sr = sample[found_key]
463
+ if in_sr != self.sample_rate:
464
+ resample_tf = T.Resample(in_sr, self.sample_rate)
465
+ audio = resample_tf(audio)
466
+
467
+ if self.sample_size is not None:
468
+ # Pad/crop and get the relative timestamp
469
+ pad_crop = PadCrop_Normalized_T(
470
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
471
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
472
+ audio)
473
+ sample["json"]["seconds_start"] = seconds_start
474
+ sample["json"]["seconds_total"] = seconds_total
475
+ sample["json"]["padding_mask"] = padding_mask
476
+ else:
477
+ t_start, t_end = 0, 1
478
+
479
+ # Check if audio is length zero, initialize to a single zero if so
480
+ if audio.shape[-1] == 0:
481
+ audio = torch.zeros(1, 1)
482
+
483
+ # Make the audio stereo and augment by randomly inverting phase
484
+ augs = torch.nn.Sequential(
485
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
486
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
487
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
488
+ )
489
+
490
+ audio = augs(audio)
491
+
492
+ sample["json"]["timestamps"] = (t_start, t_end)
493
+
494
+ if "text" in sample["json"]:
495
+ sample["json"]["prompt"] = sample["json"]["text"]
496
+
497
+ # Check for custom metadata functions
498
+ for dataset in self.datasets:
499
+ if dataset.custom_metadata_fn is None:
500
+ continue
501
+
502
+ if dataset.s3_path in sample["__url__"]:
503
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
504
+ sample["json"].update(custom_metadata)
505
+
506
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
507
+ del sample[found_key]
508
+
509
+ sample["audio"] = audio
510
+
511
+ # Add audio to the metadata as well for conditioning
512
+ sample["json"]["audio"] = audio
513
+
514
+ return sample
515
+
516
+ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
517
+
518
+ dataset_type = dataset_config.get("dataset_type", None)
519
+
520
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
521
+
522
+ if audio_channels == 1:
523
+ force_channels = "mono"
524
+ else:
525
+ force_channels = "stereo"
526
+
527
+ if dataset_type == "audio_dir":
528
+
529
+ audio_dir_configs = dataset_config.get("datasets", None)
530
+
531
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
532
+
533
+ training_dirs = []
534
+
535
+ custom_metadata_fn = None
536
+ custom_metadata_module_path = dataset_config.get("custom_metadata_module", None)
537
+
538
+ if custom_metadata_module_path is not None:
539
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
540
+ metadata_module = importlib.util.module_from_spec(spec)
541
+ spec.loader.exec_module(metadata_module)
542
+
543
+ custom_metadata_fn = metadata_module.get_custom_metadata
544
+
545
+ for audio_dir_config in audio_dir_configs:
546
+ audio_dir_path = audio_dir_config.get("path", None)
547
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
548
+ training_dirs.append(audio_dir_path)
549
+
550
+ train_set = SampleDataset(
551
+ training_dirs,
552
+ sample_rate=sample_rate,
553
+ sample_size=sample_size,
554
+ random_crop=dataset_config.get("random_crop", True),
555
+ force_channels=force_channels,
556
+ custom_metadata_fn=custom_metadata_fn,
557
+ relpath=training_dirs[0] #TODO: Make relpath relative to each training dir
558
+ )
559
+
560
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
561
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
562
+
563
+ elif dataset_type == "s3":
564
+ dataset_configs = []
565
+
566
+ for s3_config in dataset_config["datasets"]:
567
+
568
+ custom_metadata_fn = None
569
+ custom_metadata_module_path = s3_config.get("custom_metadata_module", None)
570
+
571
+ if custom_metadata_module_path is not None:
572
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
573
+ metadata_module = importlib.util.module_from_spec(spec)
574
+ spec.loader.exec_module(metadata_module)
575
+
576
+ custom_metadata_fn = metadata_module.get_custom_metadata
577
+
578
+ dataset_configs.append(
579
+ S3DatasetConfig(
580
+ id=s3_config["id"],
581
+ s3_path=s3_config["s3_path"],
582
+ custom_metadata_fn=custom_metadata_fn,
583
+ profile=s3_config.get("profile", None),
584
+ )
585
+ )
586
+
587
+ return S3WebDataLoader(
588
+ dataset_configs,
589
+ sample_rate=sample_rate,
590
+ sample_size=sample_size,
591
+ batch_size=batch_size,
592
+ random_crop=dataset_config.get("random_crop", True),
593
+ num_workers=num_workers,
594
+ persistent_workers=True,
595
+ force_channels=force_channels,
596
+ epoch_steps=dataset_config.get("epoch_steps", 2000),
597
+ ).data_loader
stable_audio_tools/data/utils.py ADDED
@@ -0,0 +1,96 @@
+ import math
+ import random
+ import torch
+
+ from torch import nn
+ from typing import Tuple
+
+ class PadCrop(nn.Module):
+     def __init__(self, n_samples, randomize=True):
+         super().__init__()
+         self.n_samples = n_samples
+         self.randomize = randomize
+
+     def __call__(self, signal):
+         n, s = signal.shape
+         start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+         end = start + self.n_samples
+         output = signal.new_zeros([n, self.n_samples])
+         output[:, :min(s, self.n_samples)] = signal[:, start:end]
+         return output
+
+ class PadCrop_Normalized_T(nn.Module):
+
+     def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+         super().__init__()
+
+         self.n_samples = n_samples
+         self.sample_rate = sample_rate
+         self.randomize = randomize
+
+     def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
+
+         n_channels, n_samples = source.shape
+
+         # Maximum offset for a random crop (0 if the audio is shorter than the target length)
+         upper_bound = max(0, n_samples - self.n_samples)
+
+         # If randomize is False, always start at the beginning of the audio
+         offset = 0
+         if self.randomize and n_samples > self.n_samples:
+             offset = random.randint(0, upper_bound)
+
+         # Calculate the relative start and end positions of the chunk (fractions of the padded length)
+         t_start = offset / (upper_bound + self.n_samples)
+         t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+
+         # Create the chunk
+         chunk = source.new_zeros([n_channels, self.n_samples])
+
+         # Copy the audio into the chunk
+         chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
+
+         # Calculate the chunk start time and the total source length, in seconds
+         seconds_start = math.floor(offset / self.sample_rate)
+         seconds_total = math.ceil(n_samples / self.sample_rate)
+
+         # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+         padding_mask = torch.zeros([self.n_samples])
+         padding_mask[:min(n_samples, self.n_samples)] = 1
+
+         return (
+             chunk,
+             t_start,
+             t_end,
+             seconds_start,
+             seconds_total,
+             padding_mask
+         )
+
+ class PhaseFlipper(nn.Module):
+     "Randomly invert the phase of a signal"
+     def __init__(self, p=0.5):
+         super().__init__()
+         self.p = p
+     def __call__(self, signal):
+         return -signal if (random.random() < self.p) else signal
+
+ class Mono(nn.Module):
+     def __call__(self, signal):
+         return torch.mean(signal, dim=0, keepdim=True) if len(signal.shape) > 1 else signal
+
+ class Stereo(nn.Module):
+     def __call__(self, signal):
+         signal_shape = signal.shape
+         # Check if the signal is mono
+         if len(signal_shape) == 1:    # s -> 2, s
+             signal = signal.unsqueeze(0).repeat(2, 1)
+         elif len(signal_shape) == 2:
+             if signal_shape[0] == 1:  # 1, s -> 2, s
+                 signal = signal.repeat(2, 1)
+             elif signal_shape[0] > 2: # ?, s -> 2, s
+                 signal = signal[:2, :]
+
+         return signal
stable_audio_tools/inference/__init__.py ADDED
File without changes
stable_audio_tools/inference/generation.py ADDED
@@ -0,0 +1,243 @@
1
+ import numpy as np
2
+ import torch
3
+ import typing as tp
4
+ import math
5
+ from torchaudio import transforms as T
6
+
7
+ from .utils import prepare_audio
8
+ from .sampling import sample, sample_k
9
+ from ..data.utils import PadCrop
10
+
11
+ def generate_diffusion_uncond(
12
+ model,
13
+ steps: int = 250,
14
+ batch_size: int = 1,
15
+ sample_size: int = 2097152,
16
+ seed: int = -1,
17
+ device: str = "cuda",
18
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
+ init_noise_level: float = 1.0,
20
+ return_latents = False,
21
+ **sampler_kwargs
22
+ ) -> torch.Tensor:
23
+
24
+ # The length of the output in audio samples
25
+ audio_sample_size = sample_size
26
+
27
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
+ if model.pretransform is not None:
29
+ sample_size = sample_size // model.pretransform.downsampling_ratio
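+ # e.g. with a 2048x downsampling autoencoder, a 2,097,152-sample request becomes 1024 latent frames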
30
+
31
+ # Seed
32
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
+ print(seed)
35
+ torch.manual_seed(seed)
36
+ # Define the initial noise immediately after setting the seed
37
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
38
+
39
+ if init_audio is not None:
40
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
41
+ in_sr, init_audio = init_audio
42
+
43
+ io_channels = model.io_channels
44
+
45
+ # For latent models, set the io_channels to the autoencoder's io_channels
46
+ if model.pretransform is not None:
47
+ io_channels = model.pretransform.io_channels
48
+
49
+ # Prepare the initial audio for use by the model
50
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
51
+
52
+ # For latent models, encode the initial audio into latents
53
+ if model.pretransform is not None:
54
+ init_audio = model.pretransform.encode(init_audio)
55
+
56
+ init_audio = init_audio.repeat(batch_size, 1, 1)
57
+ else:
58
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
59
+ init_audio = None
60
+ init_noise_level = None
61
+
62
+ # Inpainting mask
63
+
64
+ if init_audio is not None:
65
+ # variations
66
+ sampler_kwargs["sigma_max"] = init_noise_level
67
+ mask = None
68
+ else:
69
+ mask = None
70
+
71
+ # Now the generative AI part:
72
+ # k-diffusion denoising process go!
73
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
74
+
75
+ # Denoising process done.
76
+ # If this is latent diffusion, decode latents back into audio
77
+ if model.pretransform is not None and not return_latents:
78
+ sampled = model.pretransform.decode(sampled)
79
+
80
+ # Return audio
81
+ return sampled
82
+
83
+
84
+ def generate_diffusion_cond(
85
+ model,
86
+ steps: int = 250,
87
+ cfg_scale=6,
88
+ conditioning: dict = None,
89
+ conditioning_tensors: tp.Optional[dict] = None,
90
+ negative_conditioning: dict = None,
91
+ negative_conditioning_tensors: tp.Optional[dict] = None,
92
+ batch_size: int = 1,
93
+ sample_size: int = 2097152,
94
+ sample_rate: int = 48000,
95
+ seed: int = -1,
96
+ device: str = "cuda",
97
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
98
+ init_noise_level: float = 1.0,
99
+ mask_args: dict = None,
100
+ return_latents = False,
101
+ **sampler_kwargs
102
+ ) -> torch.Tensor:
103
+ """
104
+ Generate audio from a prompt using a diffusion model.
105
+
106
+ Args:
107
+ model: The diffusion model to use for generation.
108
+ steps: The number of diffusion steps to use.
109
+ cfg_scale: Classifier-free guidance scale
110
+ conditioning: A dictionary of conditioning parameters to use for generation.
111
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
+ negative_conditioning: A dictionary of negative conditioning parameters, used to steer classifier-free guidance away from these conditions.
+ negative_conditioning_tensors: A dictionary of precomputed negative conditioning tensors.
112
+ batch_size: The batch size to use for generation.
113
+ sample_size: The length of the audio to generate, in samples.
114
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
115
+ seed: The random seed to use for generation, or -1 to use a random seed.
116
+ device: The device to use for generation.
117
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
118
+ init_noise_level: The noise level to use when generating from an initial audio sample.
+ mask_args: A dictionary of inpainting mask parameters (cropfrom, pastefrom, pasteto, maskstart, maskend, softnessL, softnessR, marination); see build_mask below.
119
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
120
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
121
+ """
122
+
123
+ # The length of the output in audio samples
124
+ audio_sample_size = sample_size
125
+
126
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
127
+ if model.pretransform is not None:
128
+ sample_size = sample_size // model.pretransform.downsampling_ratio
129
+
130
+ # Seed
131
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
132
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1)
133
+ print(seed)
134
+ torch.manual_seed(seed)
135
+ # Define the initial noise immediately after setting the seed
136
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
137
+
138
+ # Conditioning
139
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
140
+ if conditioning_tensors is None:
141
+ conditioning_tensors = model.conditioner(conditioning, device)
142
+ conditioning_tensors = model.get_conditioning_inputs(conditioning_tensors)
143
+
144
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
145
+
146
+ if negative_conditioning_tensors is None:
147
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
148
+
149
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
150
+ else:
151
+ negative_conditioning_tensors = {}
152
+
153
+ if init_audio is not None:
154
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
155
+ in_sr, init_audio = init_audio
156
+
157
+ io_channels = model.io_channels
158
+
159
+ # For latent models, set the io_channels to the autoencoder's io_channels
160
+ if model.pretransform is not None:
161
+ io_channels = model.pretransform.io_channels
162
+
163
+ # Prepare the initial audio for use by the model
164
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
165
+
166
+ # For latent models, encode the initial audio into latents
167
+ if model.pretransform is not None:
168
+ init_audio = model.pretransform.encode(init_audio)
169
+
170
+ init_audio = init_audio.repeat(batch_size, 1, 1)
171
+ else:
172
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
173
+ init_audio = None
174
+ init_noise_level = None
175
+ mask_args = None
176
+
177
+ # Inpainting mask
178
+ if init_audio is not None and mask_args is not None:
179
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
180
+ # This is helpful for forward and reverse outpainting
181
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
182
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
183
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
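+ # cropfrom/pastefrom/pasteto are given as percentages (0-100) of the full sample length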
184
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
185
+ croplen = pasteto - pastefrom
186
+ if cropfrom + croplen > sample_size:
187
+ croplen = sample_size - cropfrom
188
+ cropto = cropfrom + croplen
189
+ pasteto = pastefrom + croplen
190
+ cutpaste = init_audio.new_zeros(init_audio.shape)
191
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
192
+ #print(cropfrom, cropto, pastefrom, pasteto)
193
+ init_audio = cutpaste
194
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
195
+ mask = build_mask(sample_size, mask_args)
196
+ mask = mask.to(device)
197
+ elif init_audio is not None and mask_args is None:
198
+ # variations
199
+ sampler_kwargs["sigma_max"] = init_noise_level
200
+ mask = None
201
+ else:
202
+ mask = None
203
+
204
+ # Now the generative AI part:
205
+ # k-diffusion denoising process go!
206
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_tensors, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
207
+
208
+ # v-diffusion:
209
+ #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
210
+
211
+ # Denoising process done.
212
+ # If this is latent diffusion, decode latents back into audio
213
+ if model.pretransform is not None and not return_latents:
214
+ #cast sampled latents to pretransform dtype
215
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
216
+ sampled = model.pretransform.decode(sampled)
217
+
218
+ # Return audio
219
+ return sampled
220
+
221
+ # builds a softmask given the parameters
222
+ # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
223
+ # and anything between is a mixture of old/new
224
+ # ideally 0.5 would be a half/half mixture, but I haven't verified this yet
225
+ def build_mask(sample_size, mask_args):
226
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
227
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
228
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
229
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
230
+ marination = mask_args["marination"]
231
+ # use Hann windows for softening the transition (I don't know if this is correct)
232
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
233
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
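+ # hannL ramps 0 -> 1 over softnessL samples; hannR ramps 1 -> 0 over softnessR samples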
234
+ # build the mask.
235
+ mask = torch.zeros((sample_size))
236
+ mask[maskstart:maskend] = 1
237
+ mask[maskstart:maskstart+softnessL] = hannL
238
+ mask[maskend-softnessR:maskend] = hannR
239
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
240
+ if marination > 0:
241
+ mask = mask * (1-marination)
242
+ #print(mask)
243
+ return mask
stable_audio_tools/inference/sampling.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+ import math
3
+ from tqdm import trange
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+ @torch.no_grad()
24
+ def sample(model, x, steps, eta, **extra_args):
25
+ """Draws samples from a model given starting noise. v-diffusion"""
26
+ ts = x.new_ones([x.shape[0]])
27
+
28
+ # Create the noise schedule
29
+ t = torch.linspace(1, 0, steps + 1)[:-1]
30
+
31
+ alphas, sigmas = get_alphas_sigmas(t)
32
+
33
+ # The sampling loop
34
+ for i in trange(steps):
35
+
36
+ # Get the model output (v, the predicted velocity)
37
+ with torch.cuda.amp.autocast():
38
+ v = model(x, ts * t[i], **extra_args).float()
39
+
40
+ # Predict the noise and the denoised image
41
+ pred = x * alphas[i] - v * sigmas[i]
42
+ eps = x * sigmas[i] + v * alphas[i]
43
+
44
+ # If we are not on the last timestep, compute the noisy image for the
45
+ # next timestep.
46
+ if i < steps - 1:
47
+ # If eta > 0, adjust the scaling factor for the predicted noise
48
+ # downward according to the amount of additional noise to add
49
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
50
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
51
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
52
+
53
+ # Recombine the predicted noise and predicted denoised image in the
54
+ # correct proportions for the next step
55
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
56
+
57
+ # Add the correct amount of fresh noise
58
+ if eta:
59
+ x += torch.randn_like(x) * ddim_sigma
60
+
61
+ # If we are on the last timestep, output the denoised image
62
+ return pred
63
+
64
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
65
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
66
+ def get_bmask(i, steps, mask):
67
+ strength = (i+1)/(steps)
68
+ # convert to binary mask
69
+ bmask = torch.where(mask<=strength,1,0)
70
+ return bmask
71
+
72
+ def make_cond_model_fn(model, cond_fn):
73
+ def cond_model_fn(x, sigma, **kwargs):
74
+ with torch.enable_grad():
75
+ x = x.detach().requires_grad_()
76
+ denoised = model(x, sigma, **kwargs)
77
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
78
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
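+ # Shift the denoised estimate along cond_fn's gradient, scaled by sigma**2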
79
+ return cond_denoised
80
+ return cond_model_fn
81
+
82
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
83
+ # init_data is init_audio as latents (if this is latent diffusion)
84
+ # For sampling, set both init_data and mask to None
85
+ # For variations, set init_data
86
+ # For inpainting, set both init_data & mask
87
+ def sample_k(
88
+ model_fn,
89
+ noise,
90
+ init_data=None,
91
+ mask=None,
92
+ steps=100,
93
+ sampler_type="dpmpp-2m-sde",
94
+ sigma_min=0.5,
95
+ sigma_max=50,
96
+ rho=1.0, device="cuda",
97
+ callback=None,
98
+ cond_fn=None,
99
+ **extra_args
100
+ ):
101
+
102
+ denoiser = K.external.VDenoiser(model_fn)
103
+
104
+ if cond_fn is not None:
105
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
106
+
107
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
108
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
109
+ # Scale the initial noise by sigma
110
+ noise = noise * sigmas[0]
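+ # sigmas[0] == sigma_max, so the scaled noise matches the first step of the schedule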
111
+
112
+ wrapped_callback = callback
113
+
114
+ if mask is None and init_data is not None:
115
+ # VARIATION (no inpainting)
116
+ # set the initial latent to the init_data, and noise it with initial sigma
117
+ x = init_data + noise
118
+ elif mask is not None and init_data is not None:
119
+ # INPAINTING
120
+ bmask = get_bmask(0, steps, mask)
121
+ # initial noising
122
+ input_noised = init_data + noise
123
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
124
+ x = input_noised * bmask + noise * (1-bmask)
125
+ # define the inpainting callback function (Note: side effects, it mutates x)
126
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
127
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
128
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
129
+ def inpainting_callback(args):
130
+ i = args["i"]
131
+ x = args["x"]
132
+ sigma = args["sigma"]
133
+ #denoised = args["denoised"]
134
+ # noise the init_data input with this step's appropriate amount of noise
135
+ input_noised = init_data + torch.randn_like(init_data) * sigma
136
+ # shrinking hard mask
137
+ bmask = get_bmask(i, steps, mask)
138
+ # mix input_noise with x, using binary mask
139
+ new_x = input_noised * bmask + x * (1-bmask)
140
+ # mutate x
141
+ x[:,:,:] = new_x[:,:,:]
142
+ # wrap together the inpainting callback and the user-submitted callback.
143
+ if callback is None:
144
+ wrapped_callback = inpainting_callback
145
+ else:
146
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
147
+ else:
148
+ # SAMPLING
149
+ # set the initial latent to noise
150
+ x = noise
151
+
152
+
153
+ with torch.cuda.amp.autocast():
154
+ if sampler_type == "k-heun":
155
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
156
+ elif sampler_type == "k-lms":
157
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
158
+ elif sampler_type == "k-dpmpp-2s-ancestral":
159
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
160
+ elif sampler_type == "k-dpm-2":
161
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
162
+ elif sampler_type == "k-dpm-fast":
163
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
164
+ elif sampler_type == "k-dpm-adaptive":
165
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
166
+ elif sampler_type == "dpmpp-2m-sde":
167
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
168
+ elif sampler_type == "dpmpp-3m-sde":
169
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
170
+
stable_audio_tools/inference/utils.py ADDED
@@ -0,0 +1,35 @@
+ from ..data.utils import PadCrop
+
+ from torchaudio import transforms as T
+
+ def set_audio_channels(audio, target_channels):
+     if target_channels == 1:
+         # Convert to mono
+         audio = audio.mean(1, keepdim=True)
+     elif target_channels == 2:
+         # Convert to stereo
+         if audio.shape[1] == 1:
+             audio = audio.repeat(1, 2, 1)
+         elif audio.shape[1] > 2:
+             audio = audio[:, :2, :]
+     return audio
+
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
+
+     audio = audio.to(device)
+
+     if in_sr != target_sr:
+         resample_tf = T.Resample(in_sr, target_sr).to(device)
+         audio = resample_tf(audio)
+
+     audio = PadCrop(target_length, randomize=False)(audio)
+
+     # Add batch dimension
+     if audio.dim() == 1:
+         audio = audio.unsqueeze(0).unsqueeze(0)
+     elif audio.dim() == 2:
+         audio = audio.unsqueeze(0)
+
+     audio = set_audio_channels(audio, target_channels)
+
+     return audio
stable_audio_tools/interface/__init__.py ADDED
File without changes
stable_audio_tools/interface/gradio.py ADDED
@@ -0,0 +1,788 @@
1
+ import gc
2
+ import numpy as np
3
+ import gradio as gr
4
+ import json
5
+ import torch
6
+ import torchaudio
7
+
8
+ from aeiou.viz import audio_spectrogram_image
9
+ from einops import rearrange
10
+ from safetensors.torch import load_file
11
+ from torch.nn import functional as F
12
+ from torchaudio import transforms as T
13
+
14
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
15
+ from ..models.factory import create_model_from_config
16
+ from ..models.pretrained import get_pretrained_model
17
+ from ..models.utils import load_ckpt_state_dict
18
+ from ..inference.utils import prepare_audio
19
+ from ..training.utils import copy_state_dict
20
+
21
+ # Define preset values
22
+ presets = {
23
+ "Pied Currawong": {
24
+ "latitude": -33.6467,
25
+ "longitude": 150.3246,
26
+ "temperature": 12.43,
27
+ "humidity": 86,
28
+ "wind_speed": 0.66,
29
+ "pressure": 1013,
30
+ "minutes_of_day": 369,
31
+ "day_of_year": 297,
32
+ },
33
+ "Yellow-tailed Black Cockatoo": {
34
+ "latitude": -32.8334,
35
+ "longitude": 150.2001,
36
+ "temperature": 23.23,
37
+ "humidity": 45,
38
+ "wind_speed": 1.37,
39
+ "pressure": 1009,
40
+ "minutes_of_day": 986,
41
+ "day_of_year": 78,
42
+ },
43
+ "Australian Magpie": {
44
+ "latitude": -38.522,
45
+ "longitude": 145.3365,
46
+ "temperature": 18.75,
47
+ "humidity": 67,
48
+ "wind_speed": 1.5,
49
+ "pressure": 1023,
50
+ "minutes_of_day": 940,
51
+ "day_of_year": 307,
52
+ },
53
+ "Laughing Kookaburra": {
54
+ "latitude": -27.2685099,
55
+ "longitude": 152.8587437,
56
+ "temperature": 9.02,
57
+ "humidity": 94,
58
+ "wind_speed": 1.5,
59
+ "pressure": 1025,
60
+ "minutes_of_day": 320,
61
+ "day_of_year": 236,
62
+ }
63
+ }
64
+
65
+ def update_sliders(preset_name):
66
+ preset = presets[preset_name]
67
+ return (preset["latitude"], preset["longitude"], preset["temperature"], preset["humidity"], preset["wind_speed"], preset["pressure"], preset["minutes_of_day"], preset["day_of_year"])
68
+
69
+
70
+ model = None
71
+ sample_rate = 44100
72
+ sample_size = 524288
73
+
74
+
75
+ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
76
+ global model, sample_rate, sample_size
77
+
78
+ if pretrained_name is not None:
79
+ print(f"Loading pretrained model {pretrained_name}")
80
+ model, model_config = get_pretrained_model(pretrained_name)
81
+
82
+ elif model_config is not None and model_ckpt_path is not None:
83
+ print(f"Creating model from config")
84
+ model = create_model_from_config(model_config)
85
+
86
+ print(f"Loading model checkpoint from {model_ckpt_path}")
87
+ # Load checkpoint
88
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
89
+ #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
90
+
91
+ sample_rate = model_config["sample_rate"]
92
+ sample_size = model_config["sample_size"]
93
+
94
+ if pretransform_ckpt_path is not None:
95
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
96
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
97
+ print(f"Done loading pretransform")
98
+
99
+ model.to(device).eval().requires_grad_(False)
100
+
101
+ if model_half:
102
+ model.to(torch.float16)
103
+
104
+ print(f"Done loading model")
105
+
106
+ return model, model_config
107
+
108
+ def generate_cond(
109
+ seconds_start=0,
110
+ seconds_total=30,
111
+ latitude = 0.0,
112
+ longitude = 0.0,
113
+ temperature = 0.0,
114
+ humidity = 0.0,
115
+ wind_speed = 0.0,
116
+ pressure = 0.0,
117
+ minutes_of_day = 0.0,
118
+ day_of_year = 0.0,
119
+ cfg_scale=6.0,
120
+ steps=250,
121
+ preview_every=None,
122
+ seed=-1,
123
+ sampler_type="dpmpp-2m-sde",
124
+ sigma_min=0.03,
125
+ sigma_max=50,
126
+ cfg_rescale=0.4,
127
+ use_init=False,
128
+ init_audio=None,
129
+ init_noise_level=1.0,
130
+ mask_cropfrom=None,
131
+ mask_pastefrom=None,
132
+ mask_pasteto=None,
133
+ mask_maskstart=None,
134
+ mask_maskend=None,
135
+ mask_softnessL=None,
136
+ mask_softnessR=None,
137
+ mask_marination=None,
138
+ batch_size=1
139
+ ):
140
+
141
+ if torch.cuda.is_available():
142
+ torch.cuda.empty_cache()
143
+ gc.collect()
144
+
145
+
146
+ global preview_images
147
+ preview_images = []
148
+ if preview_every == 0:
149
+ preview_every = None
150
+
151
+ # Build the conditioning dictionary for each item in the batch (note that latitude is negated here)
152
+ conditioning = [{"latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
153
+
154
+ #Get the device from the model
155
+ device = next(model.parameters()).device
156
+
157
+ seed = int(seed)
158
+
159
+ if not use_init:
160
+ init_audio = None
161
+
162
+ input_sample_size = sample_size
163
+
164
+ if init_audio is not None:
165
+ in_sr, init_audio = init_audio
166
+ # Turn into torch tensor, converting from int16 to float32
167
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
168
+
169
+ if init_audio.dim() == 1:
170
+ init_audio = init_audio.unsqueeze(0) # [1, n]
171
+ elif init_audio.dim() == 2:
172
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
173
+
174
+ if in_sr != sample_rate:
175
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
176
+ init_audio = resample_tf(init_audio)
177
+
178
+ audio_length = init_audio.shape[-1]
179
+
180
+ if audio_length > sample_size:
181
+
182
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
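+ # Round the length up to the nearest multiple of the model's minimum input length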
183
+
184
+ init_audio = (sample_rate, init_audio)
185
+
186
+ def progress_callback(callback_info):
187
+ global preview_images
188
+ denoised = callback_info["denoised"]
189
+ current_step = callback_info["i"]
190
+ sigma = callback_info["sigma"]
191
+
192
+ if (current_step - 1) % preview_every == 0:
193
+ if model.pretransform is not None:
194
+ denoised = model.pretransform.decode(denoised)
195
+ denoised = rearrange(denoised, "b d n -> d (b n)")
196
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
197
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
198
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
199
+
200
+ # If inpainting, send mask args
201
+ # This will definitely change in the future
202
+ if mask_cropfrom is not None:
203
+ mask_args = {
204
+ "cropfrom": mask_cropfrom,
205
+ "pastefrom": mask_pastefrom,
206
+ "pasteto": mask_pasteto,
207
+ "maskstart": mask_maskstart,
208
+ "maskend": mask_maskend,
209
+ "softnessL": mask_softnessL,
210
+ "softnessR": mask_softnessR,
211
+ "marination": mask_marination,
212
+ }
213
+ else:
214
+ mask_args = None
215
+
216
+ # Do the audio generation
217
+ audio = generate_diffusion_cond(
218
+ model,
219
+ conditioning=conditioning,
220
+ steps=steps,
221
+ cfg_scale=cfg_scale,
222
+ batch_size=batch_size,
223
+ sample_size=input_sample_size,
224
+ sample_rate=sample_rate,
225
+ seed=seed,
226
+ device=device,
227
+ sampler_type=sampler_type,
228
+ sigma_min=sigma_min,
229
+ sigma_max=sigma_max,
230
+ init_audio=init_audio,
231
+ init_noise_level=init_noise_level,
232
+ mask_args = mask_args,
233
+ callback = progress_callback if preview_every is not None else None,
234
+ scale_phi = cfg_rescale
235
+ )
236
+
237
+ # Convert to WAV file
238
+ audio = rearrange(audio, "b d n -> d (b n)")
239
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
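+ # Peak-normalize, clamp, and convert to 16-bit PCM before writing the WAV file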
240
+ torchaudio.save("output.wav", audio, sample_rate)
241
+
242
+ # Let's look at a nice spectrogram too
243
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
244
+
245
+ return ("output.wav", [audio_spectrogram, *preview_images])
246
+
247
+ def generate_uncond(
248
+ steps=250,
249
+ seed=-1,
250
+ sampler_type="dpmpp-2m-sde",
251
+ sigma_min=0.03,
252
+ sigma_max=50,
253
+ use_init=False,
254
+ init_audio=None,
255
+ init_noise_level=1.0,
256
+ batch_size=1,
257
+ preview_every=None
258
+ ):
259
+
260
+ global preview_images
261
+
262
+ preview_images = []
263
+
264
+ if torch.cuda.is_available():
265
+ torch.cuda.empty_cache()
266
+ gc.collect()
267
+
268
+ #Get the device from the model
269
+ device = next(model.parameters()).device
270
+
271
+ seed = int(seed)
272
+
273
+ if not use_init:
274
+ init_audio = None
275
+
276
+ input_sample_size = sample_size
277
+
278
+ if init_audio is not None:
279
+ in_sr, init_audio = init_audio
280
+ # Turn into torch tensor, converting from int16 to float32
281
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
282
+
283
+ if init_audio.dim() == 1:
284
+ init_audio = init_audio.unsqueeze(0) # [1, n]
285
+ elif init_audio.dim() == 2:
286
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
287
+
288
+ if in_sr != sample_rate:
289
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
290
+ init_audio = resample_tf(init_audio)
291
+
292
+ audio_length = init_audio.shape[-1]
293
+
294
+ if audio_length > sample_size:
295
+
296
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
297
+
298
+ init_audio = (sample_rate, init_audio)
299
+
300
+ def progress_callback(callback_info):
301
+ global preview_images
302
+ denoised = callback_info["denoised"]
303
+ current_step = callback_info["i"]
304
+ sigma = callback_info["sigma"]
305
+
306
+ if (current_step - 1) % preview_every == 0:
307
+
308
+ if model.pretransform is not None:
309
+ denoised = model.pretransform.decode(denoised)
310
+
311
+ denoised = rearrange(denoised, "b d n -> d (b n)")
312
+
313
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
314
+
315
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
316
+
317
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
318
+
319
+ audio = generate_diffusion_uncond(
320
+ model,
321
+ steps=steps,
322
+ batch_size=batch_size,
323
+ sample_size=input_sample_size,
324
+ seed=seed,
325
+ device=device,
326
+ sampler_type=sampler_type,
327
+ sigma_min=sigma_min,
328
+ sigma_max=sigma_max,
329
+ init_audio=init_audio,
330
+ init_noise_level=init_noise_level,
331
+ callback = progress_callback if preview_every is not None else None
332
+ )
333
+
334
+ audio = rearrange(audio, "b d n -> d (b n)")
335
+
336
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
337
+
338
+ torchaudio.save("output.wav", audio, sample_rate)
339
+
340
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
341
+
342
+ return ("output.wav", [audio_spectrogram, *preview_images])
343
+
344
+ def generate_lm(
345
+ temperature=1.0,
346
+ top_p=0.95,
347
+ top_k=0,
348
+ batch_size=1,
349
+ ):
350
+
351
+ if torch.cuda.is_available():
352
+ torch.cuda.empty_cache()
353
+ gc.collect()
354
+
355
+ #Get the device from the model
356
+ device = next(model.parameters()).device
357
+
358
+ audio = model.generate_audio(
359
+ batch_size=batch_size,
360
+ max_gen_len = sample_size//model.pretransform.downsampling_ratio,
361
+ conditioning=None,
362
+ temp=temperature,
363
+ top_p=top_p,
364
+ top_k=top_k,
365
+ use_cache=True
366
+ )
367
+
368
+ audio = rearrange(audio, "b d n -> d (b n)")
369
+
370
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
371
+
372
+ torchaudio.save("output.wav", audio, sample_rate)
373
+
374
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
375
+
376
+ return ("output.wav", [audio_spectrogram])
377
+
378
+
379
+ def create_uncond_sampling_ui(model_config):
380
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
381
+
382
+ with gr.Row(equal_height=False):
383
+ with gr.Column():
384
+ with gr.Row():
385
+ # Steps slider
386
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
387
+
388
+ with gr.Accordion("Sampler params", open=False):
389
+
390
+ # Seed
391
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
392
+
393
+ # Sampler params
394
+ with gr.Row():
395
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
396
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
397
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
398
+
399
+ with gr.Accordion("Init audio", open=False):
400
+ init_audio_checkbox = gr.Checkbox(label="Use init audio")
401
+ init_audio_input = gr.Audio(label="Init audio")
402
+ init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
403
+
404
+ with gr.Column():
405
+ audio_output = gr.Audio(label="Output audio", interactive=False)
406
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
407
+ send_to_init_button = gr.Button("Send to init audio", scale=1)
408
+ send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
409
+
410
+ generate_button.click(fn=generate_uncond,
411
+ inputs=[
412
+ steps_slider,
413
+ seed_textbox,
414
+ sampler_type_dropdown,
415
+ sigma_min_slider,
416
+ sigma_max_slider,
417
+ init_audio_checkbox,
418
+ init_audio_input,
419
+ init_noise_level_slider,
420
+ ],
421
+ outputs=[
422
+ audio_output,
423
+ audio_spectrogram_output
424
+ ],
425
+ api_name="generate")
426
+ def create_conditioning_slider(min_val, max_val, default_value, label):
427
+ """
428
+ Create a Gradio slider for a given conditioning parameter.
429
+
430
+ Args:
431
+ - min_val: The minimum value for the slider.
432
+ - max_val: The maximum value for the slider.
433
+ - default_value: The initial value of the slider.
+ - label: The label for the slider, which is displayed in the UI.
434
+
435
+ Returns:
436
+ - A gr.Slider object configured according to the provided parameters.
437
+ """
438
+ step = (max_val - min_val) / 1000
439
+ default_val = default_value
440
+ print(f"Creating slider for {label} with min_val={min_val}, max_val={max_val}, step={step}, default_val={default_val}")
441
+ return gr.Slider(minimum=min_val, maximum=max_val, step=step, value=default_val, label=label)
442
+
443
+ def create_sampling_ui(model_config):
444
+ with gr.Row():
445
+
446
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
447
+
448
+ model_conditioning_config = model_config["model"].get("conditioning", None)
449
+
450
+ has_seconds_start = False
451
+ has_seconds_total = False
452
+
453
+ if model_conditioning_config is not None:
454
+ for conditioning_config in model_conditioning_config["configs"]:
455
+ if conditioning_config["id"] == "seconds_start":
456
+ has_seconds_start = True
457
+ if conditioning_config["id"] == "seconds_total":
458
+ has_seconds_total = True
459
+
460
+ with gr.Row(equal_height=False):
461
+ with gr.Column():
462
+ with gr.Row():
463
+
464
+ seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
465
+
466
+ seconds_total_slider = gr.Slider(minimum=0, maximum=22, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
467
+
468
+ with gr.Row():
469
+ # Steps slider
470
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="Steps")
471
+
472
+ # Preview Every slider
473
+ preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
474
+
475
+ # CFG scale
476
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
477
+
478
+ with gr.Accordion("Climate and location", open=True):
479
+ preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label="Select Preset")
480
+
481
+ latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
482
+ if latitude_config:
483
+ latitude_slider = create_conditioning_slider(
484
+ min_val=latitude_config["config"]["min_val"],
485
+ max_val=latitude_config["config"]["max_val"],
486
+ default_value = -29.8913,
487
+ label="latitude")
488
+
489
+ longitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "longitude"), None)
490
+ if longitude_config:
491
+ longitude_slider = create_conditioning_slider(
492
+ min_val=longitude_config["config"]["min_val"],
493
+ max_val=longitude_config["config"]["max_val"],
494
+ default_value=152.4951,
495
+ label="longitude")
496
+
497
+ temperature_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "temperature"), None)
498
+ if temperature_config:
499
+ temperature_slider = create_conditioning_slider(
500
+ min_val=temperature_config["config"]["min_val"],
501
+ max_val=temperature_config["config"]["max_val"],
502
+ default_value=22.05,
503
+ label="temperature")
504
+
505
+ humidity_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "humidity"), None)
506
+ if humidity_config:
507
+ humidity_slider = create_conditioning_slider(
508
+ min_val=humidity_config["config"]["min_val"],
509
+ max_val=humidity_config["config"]["max_val"],
510
+ default_value=88,
511
+ label="humidity")
512
+
513
+ wind_speed_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "wind_speed"), None)
514
+ if wind_speed_config:
515
+ wind_speed_slider = create_conditioning_slider(
516
+ min_val=wind_speed_config["config"]["min_val"],
517
+ max_val=wind_speed_config["config"]["max_val"],
518
+ default_value=0.54,
519
+ label="wind_speed")
520
+
521
+ pressure_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "pressure"), None)
522
+ if pressure_config:
523
+ pressure_slider = create_conditioning_slider(
524
+ min_val=pressure_config["config"]["min_val"],
525
+ max_val=pressure_config["config"]["max_val"],
526
+ default_value=1021,
527
+ label="pressure")
528
+
529
+ minutes_of_day_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "minutes_of_day"), None)
530
+ if minutes_of_day_config:
531
+ minutes_of_day_slider = create_conditioning_slider(
532
+ min_val=minutes_of_day_config["config"]["min_val"],
533
+ max_val=minutes_of_day_config["config"]["max_val"],
534
+ default_value=1354,
535
+ label="minutes_of_day")
536
+
537
+ day_of_year_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "day_of_year"), None)
538
+ if day_of_year_config:
539
+ day_of_year_slider = create_conditioning_slider(
540
+ min_val=day_of_year_config["config"]["min_val"],
541
+ max_val=day_of_year_config["config"]["max_val"],
542
+ default_value=342,
543
+ label="Day of year")
544
+
545
+ with gr.Accordion("Sampler params", open=False):
546
+
547
+ # Seed
548
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
549
+
550
+ # Sampler params
551
+ with gr.Row():
552
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
553
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
554
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=50, label="Sigma max")
555
+ cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.4, label="CFG rescale amount")
556
+
557
+
558
+ # Default generation tab
559
+ with gr.Accordion("Init audio", open=False):
560
+ init_audio_input = gr.Audio(label="Init audio")
561
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=1.0, label="Init noise level")
562
+
563
+ inputs = [
564
+ seconds_start_slider,
565
+ seconds_total_slider,
566
+ latitude_slider,
567
+ longitude_slider,
568
+ temperature_slider,
569
+ humidity_slider,
570
+ wind_speed_slider,
571
+ pressure_slider,
572
+ minutes_of_day_slider,
573
+ day_of_year_slider,
574
+ cfg_scale_slider,
575
+ steps_slider,
576
+ preview_every_slider,
577
+ seed_textbox,
578
+ sampler_type_dropdown,
579
+ sigma_min_slider,
580
+ sigma_max_slider,
581
+ cfg_rescale_slider,
582
+ init_noise_level_slider
583
+ ]
584
+
585
+ with gr.Column():
586
+ audio_output = gr.Audio(label="Output audio", interactive=False)
587
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
588
+
589
+ generate_button.click(fn=generate_cond,
590
+ inputs=inputs,
591
+ outputs=[
592
+ audio_output,
593
+ audio_spectrogram_output
594
+ ],
595
+ api_name="generate")
596
+
597
+ preset_dropdown.change(
598
+ fn=update_sliders,
599
+ inputs=[preset_dropdown],
600
+ outputs=[
601
+ latitude_slider,
602
+ longitude_slider,
603
+ temperature_slider,
604
+ humidity_slider,
605
+ wind_speed_slider,
606
+ pressure_slider,
607
+ minutes_of_day_slider,
608
+ day_of_year_slider
609
+ ]
610
+ )
611
+
612
+ def create_txt2audio_ui(model_config):
613
+ with gr.Blocks() as ui:
614
+ with gr.Tab("Generation"):
615
+ create_sampling_ui(model_config)
616
+ # with gr.Tab("Inpainting"):
617
+ # create_sampling_ui(model_config, inpainting=True)
618
+ return ui
619
+
620
+ def create_diffusion_uncond_ui(model_config):
621
+ with gr.Blocks() as ui:
622
+ create_uncond_sampling_ui(model_config)
623
+
624
+ return ui
625
+
626
+ def autoencoder_process(audio, latent_noise, n_quantizers):
627
+ if torch.cuda.is_available():
628
+ torch.cuda.empty_cache()
629
+ gc.collect()
630
+
631
+ #Get the device from the model
632
+ device = next(model.parameters()).device
633
+
634
+ in_sr, audio = audio
635
+
636
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
637
+
638
+ if audio.dim() == 1:
639
+ audio = audio.unsqueeze(0)
640
+ else:
641
+ audio = audio.transpose(0, 1)
642
+
643
+ audio = model.preprocess_audio_for_encoder(audio, in_sr)
644
+ # Note: If you need to do chunked encoding, to reduce VRAM,
645
+ # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
646
+ # To turn it off, do chunked=False
647
+ # Optimal overlap and chunk_size values will depend on the model.
648
+ # See encode_audio & decode_audio in autoencoders.py for more info
649
+ # Get dtype of model
650
+ dtype = next(model.parameters()).dtype
651
+
652
+ audio = audio.to(dtype)
653
+
654
+ if n_quantizers > 0:
655
+ latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
656
+ else:
657
+ latents = model.encode_audio(audio, chunked=False)
658
+
659
+ if latent_noise > 0:
660
+ latents = latents + torch.randn_like(latents) * latent_noise
661
+
662
+ audio = model.decode_audio(latents, chunked=False)
663
+
664
+ audio = rearrange(audio, "b d n -> d (b n)")
665
+
666
+ audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
667
+
668
+ torchaudio.save("output.wav", audio, sample_rate)
669
+
670
+ return "output.wav"
671
+
672
+ def create_autoencoder_ui(model_config):
673
+
674
+ is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"]
675
+
676
+ if is_dac_rvq:
677
+ n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"]
678
+ else:
679
+ n_quantizers = 0
680
+
681
+ with gr.Blocks() as ui:
682
+ input_audio = gr.Audio(label="Input audio")
683
+ output_audio = gr.Audio(label="Output audio", interactive=False)
684
+ n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq)
685
+ latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise")
686
+ process_button = gr.Button("Process", variant='primary', scale=1)
687
+ process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process")
688
+
689
+ return ui
690
+
691
+ def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
692
+
693
+ if torch.cuda.is_available():
694
+ torch.cuda.empty_cache()
695
+ gc.collect()
696
+
697
+ #Get the device from the model
698
+ device = next(model.parameters()).device
699
+
700
+ in_sr, audio = audio
701
+
702
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
703
+
704
+ if audio.dim() == 1:
705
+ audio = audio.unsqueeze(0) # [1, n]
706
+ elif audio.dim() == 2:
707
+ audio = audio.transpose(0, 1) # [n, 2] -> [2, n]
708
+
709
+ audio = audio.unsqueeze(0)
710
+
711
+ audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max})
712
+
713
+ audio = rearrange(audio, "b d n -> d (b n)")
714
+
715
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
716
+
717
+ torchaudio.save("output.wav", audio, sample_rate)
718
+
719
+ return "output.wav"
720
+
721
+ def create_diffusion_prior_ui(model_config):
722
+ with gr.Blocks() as ui:
723
+ input_audio = gr.Audio(label="Input audio")
724
+ output_audio = gr.Audio(label="Output audio", interactive=False)
725
+ # Sampler params
726
+ with gr.Row():
727
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
728
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
729
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
730
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
731
+ process_button = gr.Button("Process", variant='primary', scale=1)
732
+ process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process")
733
+
734
+ return ui
735
+
736
+ def create_lm_ui(model_config):
737
+ with gr.Blocks() as ui:
738
+ output_audio = gr.Audio(label="Output audio", interactive=False)
739
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
740
+
741
+ # Sampling params
742
+ with gr.Row():
743
+ temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature")
744
+ top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p")
745
+ top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k")
746
+
747
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
748
+ generate_button.click(
749
+ fn=generate_lm,
750
+ inputs=[
751
+ temperature_slider,
752
+ top_p_slider,
753
+ top_k_slider
754
+ ],
755
+ outputs=[output_audio, audio_spectrogram_output],
756
+ api_name="generate"
757
+ )
758
+
759
+ return ui
760
+
761
+ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
762
+
763
+ assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
764
+
765
+ if model_config_path is not None:
766
+ # Load config from json file
767
+ with open(model_config_path) as f:
768
+ model_config = json.load(f)
769
+ else:
770
+ model_config = None
771
+
772
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
773
+ _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
774
+
775
+ model_type = model_config["model_type"]
776
+
777
+ if model_type == "diffusion_cond":
778
+ ui = create_txt2audio_ui(model_config)
779
+ elif model_type == "diffusion_uncond":
780
+ ui = create_diffusion_uncond_ui(model_config)
781
+ elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
782
+ ui = create_autoencoder_ui(model_config)
783
+ elif model_type == "diffusion_prior":
784
+ ui = create_diffusion_prior_ui(model_config)
785
+ elif model_type == "lm":
786
+ ui = create_lm_ui(model_config)
787
+
788
+ return ui
stable_audio_tools/interface/testing.py ADDED
@@ -0,0 +1,409 @@
1
+ import gc
2
+ import numpy as np
3
+ import json
4
+ import torch
5
+ import torchaudio
6
+ import os
7
+ import re
8
+
9
+ from aeiou.viz import audio_spectrogram_image
10
+ from einops import rearrange
11
+ from safetensors.torch import load_file
12
+ from torch.nn import functional as F
13
+ from torchaudio import transforms as T
14
+
15
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
16
+ from ..models.factory import create_model_from_config
17
+ from ..models.pretrained import get_pretrained_model
18
+ from ..models.utils import load_ckpt_state_dict
19
+ from ..inference.utils import prepare_audio
20
+ from ..training.utils import copy_state_dict
21
+
22
+
23
+ model = None
24
+ sample_rate = 44100
25
+ sample_size = 524288
26
+
27
+ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
28
+ global model, sample_rate, sample_size
29
+
30
+ if pretrained_name is not None:
31
+ print(f"Loading pretrained model {pretrained_name}")
32
+ model, model_config = get_pretrained_model(pretrained_name)
33
+
34
+ elif model_config is not None and model_ckpt_path is not None:
35
+ print(f"Creating model from config")
36
+ model = create_model_from_config(model_config)
37
+
38
+ print(f"Loading model checkpoint from {model_ckpt_path}")
39
+ # Load checkpoint
40
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
41
+ #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
42
+
43
+ sample_rate = model_config["sample_rate"]
44
+ sample_size = model_config["sample_size"]
45
+
46
+ if pretransform_ckpt_path is not None:
47
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
48
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
49
+ print(f"Done loading pretransform")
50
+
51
+ model.to(device).eval().requires_grad_(False)
52
+
53
+ if model_half:
54
+ model.to(torch.float16)
55
+
56
+ print(f"Done loading model")
57
+
58
+ return model, model_config
59
+
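# A small usage sketch for load_model, assuming the (hypothetical) config and checkpoint
# paths below; it relies on the json/torch imports at the top of this module and returns
# the model together with its config (which carries sample_rate and sample_size).
def load_example_model():
    with open("model_config.json") as f:        # hypothetical path
        example_config = json.load(f)
    example_model, example_config = load_model(
        model_config=example_config,
        model_ckpt_path="model.ckpt",           # hypothetical path
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    return example_model, example_config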
60
+ def generate_cond_with_path(
61
+ prompt,
62
+ negative_prompt=None,
63
+ seconds_start=0,
64
+ seconds_total=30,
65
+ latitude = 0.0,
66
+ longitude = 0.0,
67
+ temperature = 0.0,
68
+ humidity = 0.0,
69
+ wind_speed = 0.0,
70
+ pressure = 0.0,
71
+ minutes_of_day = 0.0,
72
+ day_of_year = 0.0,
73
+ cfg_scale=6.0,
74
+ steps=250,
75
+ preview_every=None,
76
+ seed=-1,
77
+ sampler_type="dpmpp-2m-sde",
78
+ sigma_min=0.03,
79
+ sigma_max=50,
80
+ cfg_rescale=0.4,
81
+ use_init=False,
82
+ init_audio=None,
83
+ init_noise_level=1.0,
84
+ mask_cropfrom=None,
85
+ mask_pastefrom=None,
86
+ mask_pasteto=None,
87
+ mask_maskstart=None,
88
+ mask_maskend=None,
89
+ mask_softnessL=None,
90
+ mask_softnessR=None,
91
+ mask_marination=None,
92
+ batch_size=1,
93
+ destination_folder=None,
94
+ file_name=None
95
+ ):
96
+
97
+ if torch.cuda.is_available():
98
+ torch.cuda.empty_cache()
99
+ gc.collect()
100
+
101
+ print(f"Prompt: {prompt}")
102
+
103
+ global preview_images
104
+ preview_images = []
105
+ if preview_every == 0:
106
+ preview_every = None
107
+
108
+ # Build the per-item conditioning dictionaries (prompt plus location/weather/time metadata)
109
+ conditioning = [{"prompt": prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
110
+
111
+ if negative_prompt:
112
+ negative_conditioning = [{"prompt": negative_prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total}] * batch_size
113
+ else:
114
+ negative_conditioning = None
115
+
116
+ #Get the device from the model
117
+ device = next(model.parameters()).device
118
+
119
+ seed = int(seed)
120
+
121
+ if not use_init:
122
+ init_audio = None
123
+
124
+ input_sample_size = sample_size
125
+
126
+ if init_audio is not None:
127
+ in_sr, init_audio = init_audio
128
+ # Turn into torch tensor, converting from int16 to float32
129
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
130
+
131
+ if init_audio.dim() == 1:
132
+ init_audio = init_audio.unsqueeze(0) # [1, n]
133
+ elif init_audio.dim() == 2:
134
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
135
+
136
+ if in_sr != sample_rate:
137
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
138
+ init_audio = resample_tf(init_audio)
139
+
140
+ audio_length = init_audio.shape[-1]
141
+
142
+ if audio_length > sample_size:
143
+
144
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
145
+
146
+ init_audio = (sample_rate, init_audio)
147
+
148
+ def progress_callback(callback_info):
149
+ global preview_images
150
+ denoised = callback_info["denoised"]
151
+ current_step = callback_info["i"]
152
+ sigma = callback_info["sigma"]
153
+
154
+ if (current_step - 1) % preview_every == 0:
155
+ if model.pretransform is not None:
156
+ denoised = model.pretransform.decode(denoised)
157
+ denoised = rearrange(denoised, "b d n -> d (b n)")
158
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
159
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
160
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
161
+
162
+ # If inpainting, send mask args
163
+ # This will definitely change in the future
164
+ if mask_cropfrom is not None:
165
+ mask_args = {
166
+ "cropfrom": mask_cropfrom,
167
+ "pastefrom": mask_pastefrom,
168
+ "pasteto": mask_pasteto,
169
+ "maskstart": mask_maskstart,
170
+ "maskend": mask_maskend,
171
+ "softnessL": mask_softnessL,
172
+ "softnessR": mask_softnessR,
173
+ "marination": mask_marination,
174
+ }
175
+ else:
176
+ mask_args = None
177
+
178
+ # Do the audio generation
179
+ audio = generate_diffusion_cond(
180
+ model,
181
+ conditioning=conditioning,
182
+ negative_conditioning=negative_conditioning,
183
+ steps=steps,
184
+ cfg_scale=cfg_scale,
185
+ batch_size=batch_size,
186
+ sample_size=input_sample_size,
187
+ sample_rate=sample_rate,
188
+ seed=seed,
189
+ device=device,
190
+ sampler_type=sampler_type,
191
+ sigma_min=sigma_min,
192
+ sigma_max=sigma_max,
193
+ init_audio=init_audio,
194
+ init_noise_level=init_noise_level,
195
+ mask_args = mask_args,
196
+ callback = progress_callback if preview_every is not None else None,
197
+ scale_phi = cfg_rescale
198
+ )
199
+
200
+ # Convert to WAV file
201
+ audio = rearrange(audio, "b d n -> d (b n)")
202
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
203
+ # Save to the destination folder under the requested file name, with a .wav extension
204
+
205
+ if destination_folder is not None and file_name is not None:
206
+ torchaudio.save(f"{destination_folder}/{file_name}.wav", audio, sample_rate)
207
+
208
+
209
+
210
+ # Let's look at a nice spectrogram too
211
+ # audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
212
+
213
+ # return ("output.wav", [audio_spectrogram, *preview_images])
214
+
215
+
216
+
217
+ def generate_lm(
218
+ temperature=1.0,
219
+ top_p=0.95,
220
+ top_k=0,
221
+ batch_size=1,
222
+ ):
223
+
224
+ if torch.cuda.is_available():
225
+ torch.cuda.empty_cache()
226
+ gc.collect()
227
+
228
+ #Get the device from the model
229
+ device = next(model.parameters()).device
230
+
231
+ audio = model.generate_audio(
232
+ batch_size=batch_size,
233
+ max_gen_len = sample_size//model.pretransform.downsampling_ratio,
234
+ conditioning=None,
235
+ temp=temperature,
236
+ top_p=top_p,
237
+ top_k=top_k,
238
+ use_cache=True
239
+ )
240
+
241
+ audio = rearrange(audio, "b d n -> d (b n)")
242
+
243
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
244
+
245
+ torchaudio.save("output.wav", audio, sample_rate)
246
+
247
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
248
+
249
+ return ("output.wav", [audio_spectrogram])
250
+
251
+
252
+
253
+
254
+ def autoencoder_process(audio, latent_noise, n_quantizers):
255
+ if torch.cuda.is_available():
256
+ torch.cuda.empty_cache()
257
+ gc.collect()
258
+
259
+ #Get the device from the model
260
+ device = next(model.parameters()).device
261
+
262
+ in_sr, audio = audio
263
+
264
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
265
+
266
+ if audio.dim() == 1:
267
+ audio = audio.unsqueeze(0)
268
+ else:
269
+ audio = audio.transpose(0, 1)
270
+
271
+ audio = model.preprocess_audio_for_encoder(audio, in_sr)
272
+ # Note: If you need to do chunked encoding, to reduce VRAM,
273
+ # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
274
+ # To turn it off, do chunked=False
275
+ # Optimal overlap and chunk_size values will depend on the model.
276
+ # See encode_audio & decode_audio in autoencoders.py for more info
277
+ # Get dtype of model
278
+ dtype = next(model.parameters()).dtype
279
+
280
+ audio = audio.to(dtype)
281
+
282
+ if n_quantizers > 0:
283
+ latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
284
+ else:
285
+ latents = model.encode_audio(audio, chunked=False)
286
+
287
+ if latent_noise > 0:
288
+ latents = latents + torch.randn_like(latents) * latent_noise
289
+
290
+ audio = model.decode_audio(latents, chunked=False)
291
+
292
+ audio = rearrange(audio, "b d n -> d (b n)")
293
+
294
+ audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
295
+
296
+ torchaudio.save("output.wav", audio, sample_rate)
297
+
298
+ return "output.wav"
299
+
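# A minimal sketch of the chunked round-trip described in the comment inside
# autoencoder_process above, assuming `model` is the loaded autoencoder and `audio` is a
# tensor already passed through preprocess_audio_for_encoder. The chunk/overlap values are
# the illustrative ones from that comment, not tuned settings.
def chunked_roundtrip(audio):
    dtype = next(model.parameters()).dtype
    latents = model.encode_audio(audio.to(dtype), chunked=True, overlap=32, chunk_size=128)
    return model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)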
300
+ def load_and_generate(model_path, json_dir, output_dir):
301
+ """Load JSON files and generate audio for each set of conditions."""
302
+ # List all files in the json_dir
303
+ files = os.listdir(json_dir)
304
+
305
+ # Filter for JSON files
306
+ json_files = [file for file in files if file.endswith('.json')]
307
+
308
+ if not json_files:
309
+ print(f"No JSON files found in {json_dir}. Please check the directory path and file permissions.")
310
+ return
311
+
312
+ for json_filename in json_files:
313
+ json_file_path = os.path.join(json_dir, json_filename)
314
+
315
+ try:
316
+ with open(json_file_path, 'r') as file:
317
+ data = json.load(file)
318
+ except Exception as e:
319
+ print(f"Failed to read or parse {json_file_path}: {e}")
320
+ continue
321
+
322
+ # Print the JSON path
323
+ print(json_file_path)
324
+
325
+ # Extract conditioning values from the JSON (see the example dict after this function)
326
+ conditions = {
327
+ 'birdSpecies': data['birdSpecies'],
328
+ 'latitude': data['coord']['lat'],
329
+ 'longitude': data['coord']['lon'],
330
+ 'temperature': data['main']['temp'],
331
+ 'humidity': data['main']['humidity'],
332
+ 'pressure': data['main']['pressure'],
333
+ 'wind_speed': data['wind']['speed'],
334
+ 'day_of_year': data['dayOfYear'],
335
+ 'minutes_of_day': data['minutesOfDay']
336
+ }
337
+
338
+ # Extract base filename components
339
+ step_number = re.search(r'step=(\d+)', model_path).group(1)
340
+ bird_species = conditions['birdSpecies'].replace(' ', '_')
341
+ base_filename = f"{bird_species}_{os.path.splitext(json_filename)[0]}_{step_number}_cfg_scale_"
342
+
343
+
344
+
345
+ # CFG scale values to test
346
+ cfg_scales = [1.8, 2.5, 4.0, 5.0, 12.0]
347
+
348
+ # Generate one audio file per CFG scale value
349
+ for scale in cfg_scales:
350
+ generate_cond_with_path(prompt = "",
351
+ negative_prompt="",
352
+ seconds_start=0,
353
+ seconds_total=22,
354
+ latitude = conditions['latitude'],
355
+ longitude = conditions['longitude'],
356
+ temperature = conditions['temperature'],
357
+ humidity = conditions['humidity'],
358
+ wind_speed = conditions['wind_speed'],
359
+ pressure = conditions['pressure'],
360
+ minutes_of_day = conditions['minutes_of_day'],
361
+ day_of_year = conditions['day_of_year'],
362
+ cfg_scale=scale,
363
+ steps=250,
364
+ preview_every=None,
365
+ seed=-1,
366
+ sampler_type="dpmpp-2m-sde",
367
+ sigma_min=0.03,
368
+ sigma_max=50,
369
+ cfg_rescale=0.4,
370
+ use_init=False,
371
+ init_audio=None,
372
+ init_noise_level=1.0,
373
+ mask_cropfrom=None,
374
+ mask_pastefrom=None,
375
+ mask_pasteto=None,
376
+ mask_maskstart=None,
377
+ mask_maskend=None,
378
+ mask_softnessL=None,
379
+ mask_softnessR=None,
380
+ mask_marination=None,
381
+ batch_size=1,
382
+ destination_folder=output_dir,
383
+ file_name=base_filename + str(scale))
384
+
385
+
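# Illustrative shape of one conditioning JSON file that load_and_generate expects;
# the field names follow the lookups above, the values are made up.
example_condition = {
    "birdSpecies": "Eurasian Blackbird",
    "coord": {"lat": 52.37, "lon": 4.90},
    "main": {"temp": 14.2, "humidity": 76, "pressure": 1013},
    "wind": {"speed": 3.4},
    "dayOfYear": 128,
    "minutesOfDay": 342,
}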
386
+ def runTests(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False, json_dir=None, output_dir=None):
387
+ assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
388
+
389
+ if model_config_path is not None:
390
+ # Load config from json file
391
+ with open(model_config_path) as f:
392
+ model_config = json.load(f)
393
+ else:
394
+ model_config = None
395
+
396
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
397
+ _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
398
+
399
+ # Ensure the output directory exists
+ os.makedirs(output_dir, exist_ok=True)
400
+
401
+ # Process all JSON files and generate audio
402
+ load_and_generate(ckpt_path, json_dir, output_dir)
403
+
404
+
405
+
406
+
407
+
408
+
409
+
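# A minimal driver sketch for runTests with hypothetical paths; it loads the model once,
# then renders every JSON condition file in json_dir at the CFG scales listed in
# load_and_generate. Note that the checkpoint file name must contain "step=<n>" for the
# regex in load_and_generate to succeed.
def run_example_tests():
    runTests(
        model_config_path="model_config.json",              # hypothetical
        ckpt_path="checkpoints/epoch=10-step=50000.ckpt",    # hypothetical
        json_dir="conditions/",                              # hypothetical
        output_dir="renders/",                               # hypothetical
    )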
stable_audio_tools/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_model_from_config, create_model_from_config_path
stable_audio_tools/models/adp.py ADDED
@@ -0,0 +1,1588 @@
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, pi, log2
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+ from packaging import version
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from torch import Tensor, einsum
16
+ from torch.backends.cuda import sdp_kernel
17
+ from torch.nn import functional as F
18
+ from dac.nn.layers import Snake1d
19
+
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+ T = TypeVar("T")
36
+
37
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val: Optional[T]) -> T:
43
+ return val is not None
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
53
+ for key in d.keys():
54
+ no_prefix = int(not key.startswith(prefix))
55
+ return_dicts[no_prefix][key] = d[key]
56
+ return return_dicts
57
+
58
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
+ if keep_prefix:
61
+ return kwargs_with_prefix, kwargs
62
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
+ return kwargs_no_prefix, kwargs
64
+
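# Quick illustration of groupby, which is how "attention_" and "stft_" kwargs are routed
# further down this file (the values are made up):
_example_kwargs = {"attention_heads": 8, "attention_multiplier": 2, "channels": 128}
_attn_kwargs, _rest = groupby("attention_", _example_kwargs)
assert _attn_kwargs == {"heads": 8, "multiplier": 2} and _rest == {"channels": 128}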
65
+ """
66
+ Convolutional Blocks
67
+ """
68
+ import typing as tp
69
+
70
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
+ # License available in LICENSES/LICENSE_META.txt
72
+
73
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
+ padding_total: int = 0) -> int:
75
+ """See `pad_for_conv1d`."""
76
+ length = x.shape[-1]
77
+ n_frames = (length - kernel_size + padding_total) / stride + 1
78
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
+ return ideal_length - length
80
+
81
+
82
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
+ """Pad for a convolution to make sure that the last window is full.
84
+ Extra padding is added at the end. This is required to ensure that we can rebuild
85
+ an output of the same length, as otherwise, even with padding, some time steps
86
+ might get removed.
87
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
88
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
90
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
+ 1 2 3 4 # once you removed padding, we are missing one time step !
92
+ """
93
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
+ return F.pad(x, (0, extra_padding))
95
+
96
+
97
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
100
+ """
101
+ length = x.shape[-1]
102
+ padding_left, padding_right = paddings
103
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
+ if mode == 'reflect':
105
+ max_pad = max(padding_left, padding_right)
106
+ extra_pad = 0
107
+ if length <= max_pad:
108
+ extra_pad = max_pad - length + 1
109
+ x = F.pad(x, (0, extra_pad))
110
+ padded = F.pad(x, paddings, mode, value)
111
+ end = padded.shape[-1] - extra_pad
112
+ return padded[..., :end]
113
+ else:
114
+ return F.pad(x, paddings, mode, value)
115
+
116
+
117
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
+ padding_left, padding_right = paddings
120
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
+ assert (padding_left + padding_right) <= x.shape[-1]
122
+ end = x.shape[-1] - padding_right
123
+ return x[..., padding_left: end]
124
+
125
+
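# Worked example of the padding helpers above, using an effective kernel_size of 4 and
# stride of 2, so padding_total = kernel_size - stride = 2 (the same quantity Conv1d
# computes below). A length of 101 needs one extra padded sample to fill the last window.
_x = torch.zeros(1, 1, 101)
_extra = get_extra_padding_for_conv1d(_x, kernel_size=4, stride=2, padding_total=2)
assert _extra == 1
assert pad_for_conv1d(_x, kernel_size=4, stride=2, padding_total=2).shape[-1] == 102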
126
+ class Conv1d(nn.Conv1d):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+
130
+ def forward(self, x: Tensor, causal=False) -> Tensor:
131
+ kernel_size = self.kernel_size[0]
132
+ stride = self.stride[0]
133
+ dilation = self.dilation[0]
134
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
+ padding_total = kernel_size - stride
136
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
+ if causal:
138
+ # Left padding for causal
139
+ x = pad1d(x, (padding_total, extra_padding))
140
+ else:
141
+ # Asymmetric padding required for odd strides
142
+ padding_right = padding_total // 2
143
+ padding_left = padding_total - padding_right
144
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
145
+ return super().forward(x)
146
+
147
+ class ConvTranspose1d(nn.ConvTranspose1d):
148
+ def __init__(self, *args, **kwargs):
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def forward(self, x: Tensor, causal=False) -> Tensor:
152
+ kernel_size = self.kernel_size[0]
153
+ stride = self.stride[0]
154
+ padding_total = kernel_size - stride
155
+
156
+ y = super().forward(x)
157
+
158
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
+ # removed at the very end, when keeping only the right length for the output,
160
+ # as removing it here would require also passing the length at the matching layer
161
+ # in the encoder.
162
+ if causal:
163
+ padding_right = ceil(padding_total)
164
+ padding_left = padding_total - padding_right
165
+ y = unpad1d(y, (padding_left, padding_right))
166
+ else:
167
+ # Asymmetric padding required for odd strides
168
+ padding_right = padding_total // 2
169
+ padding_left = padding_total - padding_right
170
+ y = unpad1d(y, (padding_left, padding_right))
171
+ return y
172
+
173
+
174
+ def Downsample1d(
175
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
+ ) -> nn.Module:
177
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
+
179
+ return Conv1d(
180
+ in_channels=in_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=factor * kernel_multiplier + 1,
183
+ stride=factor
184
+ )
185
+
186
+
187
+ def Upsample1d(
188
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
+ ) -> nn.Module:
190
+
191
+ if factor == 1:
192
+ return Conv1d(
193
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
+ )
195
+
196
+ if use_nearest:
197
+ return nn.Sequential(
198
+ nn.Upsample(scale_factor=factor, mode="nearest"),
199
+ Conv1d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=3
203
+ ),
204
+ )
205
+ else:
206
+ return ConvTranspose1d(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=factor * 2,
210
+ stride=factor
211
+ )
212
+
213
+
214
+ class ConvBlock1d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ *,
220
+ kernel_size: int = 3,
221
+ stride: int = 1,
222
+ dilation: int = 1,
223
+ num_groups: int = 8,
224
+ use_norm: bool = True,
225
+ use_snake: bool = False
226
+ ) -> None:
227
+ super().__init__()
228
+
229
+ self.groupnorm = (
230
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
+ if use_norm
232
+ else nn.Identity()
233
+ )
234
+
235
+ if use_snake:
236
+ self.activation = Snake1d(in_channels)
237
+ else:
238
+ self.activation = nn.SiLU()
239
+
240
+ self.project = Conv1d(
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ kernel_size=kernel_size,
244
+ stride=stride,
245
+ dilation=dilation,
246
+ )
247
+
248
+ def forward(
249
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
+ ) -> Tensor:
251
+ x = self.groupnorm(x)
252
+ if exists(scale_shift):
253
+ scale, shift = scale_shift
254
+ x = x * (scale + 1) + shift
255
+ x = self.activation(x)
256
+ return self.project(x, causal=causal)
257
+
258
+
259
+ class MappingToScaleShift(nn.Module):
260
+ def __init__(
261
+ self,
262
+ features: int,
263
+ channels: int,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.to_scale_shift = nn.Sequential(
268
+ nn.SiLU(),
269
+ nn.Linear(in_features=features, out_features=channels * 2),
270
+ )
271
+
272
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
+ scale_shift = self.to_scale_shift(mapping)
274
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
+ scale, shift = scale_shift.chunk(2, dim=1)
276
+ return scale, shift
277
+
278
+
279
+ class ResnetBlock1d(nn.Module):
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ out_channels: int,
284
+ *,
285
+ kernel_size: int = 3,
286
+ stride: int = 1,
287
+ dilation: int = 1,
288
+ use_norm: bool = True,
289
+ use_snake: bool = False,
290
+ num_groups: int = 8,
291
+ context_mapping_features: Optional[int] = None,
292
+ ) -> None:
293
+ super().__init__()
294
+
295
+ self.use_mapping = exists(context_mapping_features)
296
+
297
+ self.block1 = ConvBlock1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=kernel_size,
301
+ stride=stride,
302
+ dilation=dilation,
303
+ use_norm=use_norm,
304
+ num_groups=num_groups,
305
+ use_snake=use_snake
306
+ )
307
+
308
+ if self.use_mapping:
309
+ assert exists(context_mapping_features)
310
+ self.to_scale_shift = MappingToScaleShift(
311
+ features=context_mapping_features, channels=out_channels
312
+ )
313
+
314
+ self.block2 = ConvBlock1d(
315
+ in_channels=out_channels,
316
+ out_channels=out_channels,
317
+ use_norm=use_norm,
318
+ num_groups=num_groups,
319
+ use_snake=use_snake
320
+ )
321
+
322
+ self.to_out = (
323
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
+ if in_channels != out_channels
325
+ else nn.Identity()
326
+ )
327
+
328
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
+ assert_message = "context mapping required if context_mapping_features > 0"
330
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
331
+
332
+ h = self.block1(x, causal=causal)
333
+
334
+ scale_shift = None
335
+ if self.use_mapping:
336
+ scale_shift = self.to_scale_shift(mapping)
337
+
338
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
+
340
+ return h + self.to_out(x)
341
+
342
+
343
+ class Patcher(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ patch_size: int,
349
+ context_mapping_features: Optional[int] = None,
350
+ use_snake: bool = False,
351
+ ):
352
+ super().__init__()
353
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
+ assert out_channels % patch_size == 0, assert_message
355
+ self.patch_size = patch_size
356
+
357
+ self.block = ResnetBlock1d(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels // patch_size,
360
+ num_groups=1,
361
+ context_mapping_features=context_mapping_features,
362
+ use_snake=use_snake
363
+ )
364
+
365
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
+ x = self.block(x, mapping, causal=causal)
367
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
+ return x
369
+
370
+
371
+ class Unpatcher(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ patch_size: int,
377
+ context_mapping_features: Optional[int] = None,
378
+ use_snake: bool = False
379
+ ):
380
+ super().__init__()
381
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
+ assert in_channels % patch_size == 0, assert_message
383
+ self.patch_size = patch_size
384
+
385
+ self.block = ResnetBlock1d(
386
+ in_channels=in_channels // patch_size,
387
+ out_channels=out_channels,
388
+ num_groups=1,
389
+ context_mapping_features=context_mapping_features,
390
+ use_snake=use_snake
391
+ )
392
+
393
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
+ x = self.block(x, mapping, causal=causal)
396
+ return x
397
+
398
+
399
+ """
400
+ Attention Components
401
+ """
402
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
403
+ mid_features = features * multiplier
404
+ return nn.Sequential(
405
+ nn.Linear(in_features=features, out_features=mid_features),
406
+ nn.GELU(),
407
+ nn.Linear(in_features=mid_features, out_features=features),
408
+ )
409
+
410
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
+ b, ndim = sim.shape[0], mask.ndim
412
+ if ndim == 3:
413
+ mask = rearrange(mask, "b n m -> b 1 n m")
414
+ if ndim == 2:
415
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
416
+ max_neg_value = -torch.finfo(sim.dtype).max
417
+ sim = sim.masked_fill(~mask, max_neg_value)
418
+ return sim
419
+
420
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
+ mask = repeat(mask, "n m -> b n m", b=b)
424
+ return mask
425
+
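# Shape sketch for causal_mask: the mask is lower-triangular, so position i may only
# attend to positions <= i (illustrative shapes, sequence length 3):
_q = torch.zeros(1, 3, 8)
_k = torch.zeros(1, 3, 8)
_m = causal_mask(_q, _k)
assert _m.shape == (1, 3, 3)
assert bool(_m[0, 1, 0]) and not bool(_m[0, 0, 1])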
426
+ class AttentionBase(nn.Module):
427
+ def __init__(
428
+ self,
429
+ features: int,
430
+ *,
431
+ head_features: int,
432
+ num_heads: int,
433
+ out_features: Optional[int] = None,
434
+ ):
435
+ super().__init__()
436
+ self.scale = head_features**-0.5
437
+ self.num_heads = num_heads
438
+ mid_features = head_features * num_heads
439
+ out_features = default(out_features, features)
440
+
441
+ self.to_out = nn.Linear(
442
+ in_features=mid_features, out_features=out_features
443
+ )
444
+
445
+ self.use_flash = False
446
+
447
+ if not self.use_flash:
448
+ return
449
+
450
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
+
452
+ if device_properties.major == 8 and device_properties.minor == 0:
453
+ # A100 GPUs: flash attention is left disabled here; the math/mem-efficient kernels are used
454
+ self.sdp_kernel_config = (False, True, True)
455
+ else:
456
+ # Don't use flash attention for other GPUs
457
+ self.sdp_kernel_config = (False, True, True)
458
+
459
+ def forward(
460
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
+ ) -> Tensor:
462
+ # Split heads
463
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
+
465
+ if not self.use_flash:
466
+ if is_causal and mask is None:
467
+ # Mask out future tokens for causal attention
468
+ mask = causal_mask(q, k)
469
+
470
+ # Compute similarity matrix and add eventual mask
471
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
+ sim = add_mask(sim, mask) if exists(mask) else sim
473
+
474
+ # Get attention matrix with softmax
475
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
476
+
477
+ # Compute values
478
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
479
+ else:
480
+ with sdp_kernel(*self.sdp_kernel_config):
481
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
+
483
+ out = rearrange(out, "b h n d -> b n (h d)")
484
+ return self.to_out(out)
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self,
489
+ features: int,
490
+ *,
491
+ head_features: int,
492
+ num_heads: int,
493
+ out_features: Optional[int] = None,
494
+ context_features: Optional[int] = None,
495
+ causal: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.context_features = context_features
499
+ self.causal = causal
500
+ mid_features = head_features * num_heads
501
+ context_features = default(context_features, features)
502
+
503
+ self.norm = nn.LayerNorm(features)
504
+ self.norm_context = nn.LayerNorm(context_features)
505
+ self.to_q = nn.Linear(
506
+ in_features=features, out_features=mid_features, bias=False
507
+ )
508
+ self.to_kv = nn.Linear(
509
+ in_features=context_features, out_features=mid_features * 2, bias=False
510
+ )
511
+ self.attention = AttentionBase(
512
+ features,
513
+ num_heads=num_heads,
514
+ head_features=head_features,
515
+ out_features=out_features,
516
+ )
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor, # [b, n, c]
521
+ context: Optional[Tensor] = None, # [b, m, d]
522
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
523
+ causal: Optional[bool] = False,
524
+ ) -> Tensor:
525
+ assert_message = "You must provide a context when using context_features"
526
+ assert not self.context_features or exists(context), assert_message
527
+ # Use context if provided
528
+ context = default(context, x)
529
+ # Normalize then compute q from input and k,v from context
530
+ x, context = self.norm(x), self.norm_context(context)
531
+
532
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
+
534
+ if exists(context_mask):
535
+ # Mask out cross-attention for padding tokens
536
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
+ k, v = k * mask, v * mask
538
+
539
+ # Compute and return attention
540
+ return self.attention(q, k, v, is_causal=self.causal or causal)
541
+
542
+
543
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
544
+ mid_features = features * multiplier
545
+ return nn.Sequential(
546
+ nn.Linear(in_features=features, out_features=mid_features),
547
+ nn.GELU(),
548
+ nn.Linear(in_features=mid_features, out_features=features),
549
+ )
550
+
551
+ """
552
+ Transformer Blocks
553
+ """
554
+
555
+
556
+ class TransformerBlock(nn.Module):
557
+ def __init__(
558
+ self,
559
+ features: int,
560
+ num_heads: int,
561
+ head_features: int,
562
+ multiplier: int,
563
+ context_features: Optional[int] = None,
564
+ ):
565
+ super().__init__()
566
+
567
+ self.use_cross_attention = exists(context_features) and context_features > 0
568
+
569
+ self.attention = Attention(
570
+ features=features,
571
+ num_heads=num_heads,
572
+ head_features=head_features
573
+ )
574
+
575
+ if self.use_cross_attention:
576
+ self.cross_attention = Attention(
577
+ features=features,
578
+ num_heads=num_heads,
579
+ head_features=head_features,
580
+ context_features=context_features
581
+ )
582
+
583
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
+
585
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
+ x = self.attention(x, causal=causal) + x
587
+ if self.use_cross_attention:
588
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
+ x = self.feed_forward(x) + x
590
+ return x
591
+
592
+
593
+ """
594
+ Transformers
595
+ """
596
+
597
+
598
+ class Transformer1d(nn.Module):
599
+ def __init__(
600
+ self,
601
+ num_layers: int,
602
+ channels: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ context_features: Optional[int] = None,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.to_in = nn.Sequential(
611
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
+ Conv1d(
613
+ in_channels=channels,
614
+ out_channels=channels,
615
+ kernel_size=1,
616
+ ),
617
+ Rearrange("b c t -> b t c"),
618
+ )
619
+
620
+ self.blocks = nn.ModuleList(
621
+ [
622
+ TransformerBlock(
623
+ features=channels,
624
+ head_features=head_features,
625
+ num_heads=num_heads,
626
+ multiplier=multiplier,
627
+ context_features=context_features,
628
+ )
629
+ for i in range(num_layers)
630
+ ]
631
+ )
632
+
633
+ self.to_out = nn.Sequential(
634
+ Rearrange("b t c -> b c t"),
635
+ Conv1d(
636
+ in_channels=channels,
637
+ out_channels=channels,
638
+ kernel_size=1,
639
+ ),
640
+ )
641
+
642
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
+ x = self.to_in(x)
644
+ for block in self.blocks:
645
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
646
+ x = self.to_out(x)
647
+ return x
648
+
649
+
650
+ """
651
+ Time Embeddings
652
+ """
653
+
654
+
655
+ class SinusoidalEmbedding(nn.Module):
656
+ def __init__(self, dim: int):
657
+ super().__init__()
658
+ self.dim = dim
659
+
660
+ def forward(self, x: Tensor) -> Tensor:
661
+ device, half_dim = x.device, self.dim // 2
662
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
+
667
+
668
+ class LearnedPositionalEmbedding(nn.Module):
669
+ """Used for continuous time"""
670
+
671
+ def __init__(self, dim: int):
672
+ super().__init__()
673
+ assert (dim % 2) == 0
674
+ half_dim = dim // 2
675
+ self.weights = nn.Parameter(torch.randn(half_dim))
676
+
677
+ def forward(self, x: Tensor) -> Tensor:
678
+ x = rearrange(x, "b -> b 1")
679
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
+ fouriered = torch.cat((x, fouriered), dim=-1)
682
+ return fouriered
683
+
684
+
685
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
+ return nn.Sequential(
687
+ LearnedPositionalEmbedding(dim),
688
+ nn.Linear(in_features=dim + 1, out_features=out_features),
689
+ )
690
+
691
+
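# Shape sketch for the time embeddings above: a batch of scalar timesteps becomes a
# [batch, out_features] vector that later feeds the UNet mapping network (the dimensions
# here are illustrative):
_t = torch.rand(4)
_time_emb = TimePositionalEmbedding(dim=128, out_features=512)
assert _time_emb(_t).shape == (4, 512)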
692
+ """
693
+ Encoder/Decoder Components
694
+ """
695
+
696
+
697
+ class DownsampleBlock1d(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ *,
703
+ factor: int,
704
+ num_groups: int,
705
+ num_layers: int,
706
+ kernel_multiplier: int = 2,
707
+ use_pre_downsample: bool = True,
708
+ use_skip: bool = False,
709
+ use_snake: bool = False,
710
+ extract_channels: int = 0,
711
+ context_channels: int = 0,
712
+ num_transformer_blocks: int = 0,
713
+ attention_heads: Optional[int] = None,
714
+ attention_features: Optional[int] = None,
715
+ attention_multiplier: Optional[int] = None,
716
+ context_mapping_features: Optional[int] = None,
717
+ context_embedding_features: Optional[int] = None,
718
+ ):
719
+ super().__init__()
720
+ self.use_pre_downsample = use_pre_downsample
721
+ self.use_skip = use_skip
722
+ self.use_transformer = num_transformer_blocks > 0
723
+ self.use_extract = extract_channels > 0
724
+ self.use_context = context_channels > 0
725
+
726
+ channels = out_channels if use_pre_downsample else in_channels
727
+
728
+ self.downsample = Downsample1d(
729
+ in_channels=in_channels,
730
+ out_channels=out_channels,
731
+ factor=factor,
732
+ kernel_multiplier=kernel_multiplier,
733
+ )
734
+
735
+ self.blocks = nn.ModuleList(
736
+ [
737
+ ResnetBlock1d(
738
+ in_channels=channels + context_channels if i == 0 else channels,
739
+ out_channels=channels,
740
+ num_groups=num_groups,
741
+ context_mapping_features=context_mapping_features,
742
+ use_snake=use_snake
743
+ )
744
+ for i in range(num_layers)
745
+ ]
746
+ )
747
+
748
+ if self.use_transformer:
749
+ assert (
750
+ (exists(attention_heads) or exists(attention_features))
751
+ and exists(attention_multiplier)
752
+ )
753
+
754
+ if attention_features is None and attention_heads is not None:
755
+ attention_features = channels // attention_heads
756
+
757
+ if attention_heads is None and attention_features is not None:
758
+ attention_heads = channels // attention_features
759
+
760
+ self.transformer = Transformer1d(
761
+ num_layers=num_transformer_blocks,
762
+ channels=channels,
763
+ num_heads=attention_heads,
764
+ head_features=attention_features,
765
+ multiplier=attention_multiplier,
766
+ context_features=context_embedding_features
767
+ )
768
+
769
+ if self.use_extract:
770
+ num_extract_groups = min(num_groups, extract_channels)
771
+ self.to_extracted = ResnetBlock1d(
772
+ in_channels=out_channels,
773
+ out_channels=extract_channels,
774
+ num_groups=num_extract_groups,
775
+ use_snake=use_snake
776
+ )
777
+
778
+ def forward(
779
+ self,
780
+ x: Tensor,
781
+ *,
782
+ mapping: Optional[Tensor] = None,
783
+ channels: Optional[Tensor] = None,
784
+ embedding: Optional[Tensor] = None,
785
+ embedding_mask: Optional[Tensor] = None,
786
+ causal: Optional[bool] = False
787
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
+
789
+ if self.use_pre_downsample:
790
+ x = self.downsample(x)
791
+
792
+ if self.use_context and exists(channels):
793
+ x = torch.cat([x, channels], dim=1)
794
+
795
+ skips = []
796
+ for block in self.blocks:
797
+ x = block(x, mapping=mapping, causal=causal)
798
+ skips += [x] if self.use_skip else []
799
+
800
+ if self.use_transformer:
801
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
+ skips += [x] if self.use_skip else []
803
+
804
+ if not self.use_pre_downsample:
805
+ x = self.downsample(x)
806
+
807
+ if self.use_extract:
808
+ extracted = self.to_extracted(x)
809
+ return x, extracted
810
+
811
+ return (x, skips) if self.use_skip else x
812
+
813
+
814
+ class UpsampleBlock1d(nn.Module):
815
+ def __init__(
816
+ self,
817
+ in_channels: int,
818
+ out_channels: int,
819
+ *,
820
+ factor: int,
821
+ num_layers: int,
822
+ num_groups: int,
823
+ use_nearest: bool = False,
824
+ use_pre_upsample: bool = False,
825
+ use_skip: bool = False,
826
+ use_snake: bool = False,
827
+ skip_channels: int = 0,
828
+ use_skip_scale: bool = False,
829
+ extract_channels: int = 0,
830
+ num_transformer_blocks: int = 0,
831
+ attention_heads: Optional[int] = None,
832
+ attention_features: Optional[int] = None,
833
+ attention_multiplier: Optional[int] = None,
834
+ context_mapping_features: Optional[int] = None,
835
+ context_embedding_features: Optional[int] = None,
836
+ ):
837
+ super().__init__()
838
+
839
+ self.use_extract = extract_channels > 0
840
+ self.use_pre_upsample = use_pre_upsample
841
+ self.use_transformer = num_transformer_blocks > 0
842
+ self.use_skip = use_skip
843
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
+
845
+ channels = out_channels if use_pre_upsample else in_channels
846
+
847
+ self.blocks = nn.ModuleList(
848
+ [
849
+ ResnetBlock1d(
850
+ in_channels=channels + skip_channels,
851
+ out_channels=channels,
852
+ num_groups=num_groups,
853
+ context_mapping_features=context_mapping_features,
854
+ use_snake=use_snake
855
+ )
856
+ for _ in range(num_layers)
857
+ ]
858
+ )
859
+
860
+ if self.use_transformer:
861
+ assert (
862
+ (exists(attention_heads) or exists(attention_features))
863
+ and exists(attention_multiplier)
864
+ )
865
+
866
+ if attention_features is None and attention_heads is not None:
867
+ attention_features = channels // attention_heads
868
+
869
+ if attention_heads is None and attention_features is not None:
870
+ attention_heads = channels // attention_features
871
+
872
+ self.transformer = Transformer1d(
873
+ num_layers=num_transformer_blocks,
874
+ channels=channels,
875
+ num_heads=attention_heads,
876
+ head_features=attention_features,
877
+ multiplier=attention_multiplier,
878
+ context_features=context_embedding_features,
879
+ )
880
+
881
+ self.upsample = Upsample1d(
882
+ in_channels=in_channels,
883
+ out_channels=out_channels,
884
+ factor=factor,
885
+ use_nearest=use_nearest,
886
+ )
887
+
888
+ if self.use_extract:
889
+ num_extract_groups = min(num_groups, extract_channels)
890
+ self.to_extracted = ResnetBlock1d(
891
+ in_channels=out_channels,
892
+ out_channels=extract_channels,
893
+ num_groups=num_extract_groups,
894
+ use_snake=use_snake
895
+ )
896
+
897
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
+ return torch.cat([x, skip * self.skip_scale], dim=1)
899
+
900
+ def forward(
901
+ self,
902
+ x: Tensor,
903
+ *,
904
+ skips: Optional[List[Tensor]] = None,
905
+ mapping: Optional[Tensor] = None,
906
+ embedding: Optional[Tensor] = None,
907
+ embedding_mask: Optional[Tensor] = None,
908
+ causal: Optional[bool] = False
909
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
+
911
+ if self.use_pre_upsample:
912
+ x = self.upsample(x)
913
+
914
+ for block in self.blocks:
915
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
+ x = block(x, mapping=mapping, causal=causal)
917
+
918
+ if self.use_transformer:
919
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
+
921
+ if not self.use_pre_upsample:
922
+ x = self.upsample(x)
923
+
924
+ if self.use_extract:
925
+ extracted = self.to_extracted(x)
926
+ return x, extracted
927
+
928
+ return x
929
+
930
+
931
+ class BottleneckBlock1d(nn.Module):
932
+ def __init__(
933
+ self,
934
+ channels: int,
935
+ *,
936
+ num_groups: int,
937
+ num_transformer_blocks: int = 0,
938
+ attention_heads: Optional[int] = None,
939
+ attention_features: Optional[int] = None,
940
+ attention_multiplier: Optional[int] = None,
941
+ context_mapping_features: Optional[int] = None,
942
+ context_embedding_features: Optional[int] = None,
943
+ use_snake: bool = False,
944
+ ):
945
+ super().__init__()
946
+ self.use_transformer = num_transformer_blocks > 0
947
+
948
+ self.pre_block = ResnetBlock1d(
949
+ in_channels=channels,
950
+ out_channels=channels,
951
+ num_groups=num_groups,
952
+ context_mapping_features=context_mapping_features,
953
+ use_snake=use_snake
954
+ )
955
+
956
+ if self.use_transformer:
957
+ assert (
958
+ (exists(attention_heads) or exists(attention_features))
959
+ and exists(attention_multiplier)
960
+ )
961
+
962
+ if attention_features is None and attention_heads is not None:
963
+ attention_features = channels // attention_heads
964
+
965
+ if attention_heads is None and attention_features is not None:
966
+ attention_heads = channels // attention_features
967
+
968
+ self.transformer = Transformer1d(
969
+ num_layers=num_transformer_blocks,
970
+ channels=channels,
971
+ num_heads=attention_heads,
972
+ head_features=attention_features,
973
+ multiplier=attention_multiplier,
974
+ context_features=context_embedding_features,
975
+ )
976
+
977
+ self.post_block = ResnetBlock1d(
978
+ in_channels=channels,
979
+ out_channels=channels,
980
+ num_groups=num_groups,
981
+ context_mapping_features=context_mapping_features,
982
+ use_snake=use_snake
983
+ )
984
+
985
+ def forward(
986
+ self,
987
+ x: Tensor,
988
+ *,
989
+ mapping: Optional[Tensor] = None,
990
+ embedding: Optional[Tensor] = None,
991
+ embedding_mask: Optional[Tensor] = None,
992
+ causal: Optional[bool] = False
993
+ ) -> Tensor:
994
+ x = self.pre_block(x, mapping=mapping, causal=causal)
995
+ if self.use_transformer:
996
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
+ x = self.post_block(x, mapping=mapping, causal=causal)
998
+ return x
999
+
1000
+
1001
+ """
1002
+ UNet
1003
+ """
1004
+
1005
+
1006
+ class UNet1d(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_channels: int,
1010
+ channels: int,
1011
+ multipliers: Sequence[int],
1012
+ factors: Sequence[int],
1013
+ num_blocks: Sequence[int],
1014
+ attentions: Sequence[int],
1015
+ patch_size: int = 1,
1016
+ resnet_groups: int = 8,
1017
+ use_context_time: bool = True,
1018
+ kernel_multiplier_downsample: int = 2,
1019
+ use_nearest_upsample: bool = False,
1020
+ use_skip_scale: bool = True,
1021
+ use_snake: bool = False,
1022
+ use_stft: bool = False,
1023
+ use_stft_context: bool = False,
1024
+ out_channels: Optional[int] = None,
1025
+ context_features: Optional[int] = None,
1026
+ context_features_multiplier: int = 4,
1027
+ context_channels: Optional[Sequence[int]] = None,
1028
+ context_embedding_features: Optional[int] = None,
1029
+ **kwargs,
1030
+ ):
1031
+ super().__init__()
1032
+ out_channels = default(out_channels, in_channels)
1033
+ context_channels = list(default(context_channels, []))
1034
+ num_layers = len(multipliers) - 1
1035
+ use_context_features = exists(context_features)
1036
+ use_context_channels = len(context_channels) > 0
1037
+ context_mapping_features = None
1038
+
1039
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
+
1041
+ self.num_layers = num_layers
1042
+ self.use_context_time = use_context_time
1043
+ self.use_context_features = use_context_features
1044
+ self.use_context_channels = use_context_channels
1045
+ self.use_stft = use_stft
1046
+ self.use_stft_context = use_stft_context
1047
+
1048
+ self.context_features = context_features
1049
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
+ context_channels = context_channels + [0] * context_channels_pad_length
1051
+ self.context_channels = context_channels
1052
+ self.context_embedding_features = context_embedding_features
1053
+
1054
+ if use_context_channels:
1055
+ has_context = [c > 0 for c in context_channels]
1056
+ self.has_context = has_context
1057
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
+
1059
+ assert (
1060
+ len(factors) == num_layers
1061
+ and len(attentions) >= num_layers
1062
+ and len(num_blocks) == num_layers
1063
+ )
1064
+
1065
+ if use_context_time or use_context_features:
1066
+ context_mapping_features = channels * context_features_multiplier
1067
+
1068
+ self.to_mapping = nn.Sequential(
1069
+ nn.Linear(context_mapping_features, context_mapping_features),
1070
+ nn.GELU(),
1071
+ nn.Linear(context_mapping_features, context_mapping_features),
1072
+ nn.GELU(),
1073
+ )
1074
+
1075
+ if use_context_time:
1076
+ assert exists(context_mapping_features)
1077
+ self.to_time = nn.Sequential(
1078
+ TimePositionalEmbedding(
1079
+ dim=channels, out_features=context_mapping_features
1080
+ ),
1081
+ nn.GELU(),
1082
+ )
1083
+
1084
+ if use_context_features:
1085
+ assert exists(context_features) and exists(context_mapping_features)
1086
+ self.to_features = nn.Sequential(
1087
+ nn.Linear(
1088
+ in_features=context_features, out_features=context_mapping_features
1089
+ ),
1090
+ nn.GELU(),
1091
+ )
1092
+
1093
+ if use_stft:
1094
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
+ in_channels *= stft_channels
1098
+ out_channels *= stft_channels
1099
+ context_channels[0] *= stft_channels if use_stft_context else 1
1100
+ assert exists(in_channels) and exists(out_channels)
1101
+ self.stft = STFT(**stft_kwargs)
1102
+
1103
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
+
1105
+ self.to_in = Patcher(
1106
+ in_channels=in_channels + context_channels[0],
1107
+ out_channels=channels * multipliers[0],
1108
+ patch_size=patch_size,
1109
+ context_mapping_features=context_mapping_features,
1110
+ use_snake=use_snake
1111
+ )
1112
+
1113
+ self.downsamples = nn.ModuleList(
1114
+ [
1115
+ DownsampleBlock1d(
1116
+ in_channels=channels * multipliers[i],
1117
+ out_channels=channels * multipliers[i + 1],
1118
+ context_mapping_features=context_mapping_features,
1119
+ context_channels=context_channels[i + 1],
1120
+ context_embedding_features=context_embedding_features,
1121
+ num_layers=num_blocks[i],
1122
+ factor=factors[i],
1123
+ kernel_multiplier=kernel_multiplier_downsample,
1124
+ num_groups=resnet_groups,
1125
+ use_pre_downsample=True,
1126
+ use_skip=True,
1127
+ use_snake=use_snake,
1128
+ num_transformer_blocks=attentions[i],
1129
+ **attention_kwargs,
1130
+ )
1131
+ for i in range(num_layers)
1132
+ ]
1133
+ )
1134
+
1135
+ self.bottleneck = BottleneckBlock1d(
1136
+ channels=channels * multipliers[-1],
1137
+ context_mapping_features=context_mapping_features,
1138
+ context_embedding_features=context_embedding_features,
1139
+ num_groups=resnet_groups,
1140
+ num_transformer_blocks=attentions[-1],
1141
+ use_snake=use_snake,
1142
+ **attention_kwargs,
1143
+ )
1144
+
1145
+ self.upsamples = nn.ModuleList(
1146
+ [
1147
+ UpsampleBlock1d(
1148
+ in_channels=channels * multipliers[i + 1],
1149
+ out_channels=channels * multipliers[i],
1150
+ context_mapping_features=context_mapping_features,
1151
+ context_embedding_features=context_embedding_features,
1152
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
+ factor=factors[i],
1154
+ use_nearest=use_nearest_upsample,
1155
+ num_groups=resnet_groups,
1156
+ use_skip_scale=use_skip_scale,
1157
+ use_pre_upsample=False,
1158
+ use_skip=True,
1159
+ use_snake=use_snake,
1160
+ skip_channels=channels * multipliers[i + 1],
1161
+ num_transformer_blocks=attentions[i],
1162
+ **attention_kwargs,
1163
+ )
1164
+ for i in reversed(range(num_layers))
1165
+ ]
1166
+ )
1167
+
1168
+ self.to_out = Unpatcher(
1169
+ in_channels=channels * multipliers[0],
1170
+ out_channels=out_channels,
1171
+ patch_size=patch_size,
1172
+ context_mapping_features=context_mapping_features,
1173
+ use_snake=use_snake
1174
+ )
1175
+
1176
+ def get_channels(
1177
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
+ ) -> Optional[Tensor]:
1179
+ """Gets context channels at `layer` and checks that shape is correct"""
1180
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1181
+ if not use_context_channels:
1182
+ return None
1183
+ assert exists(channels_list), "Missing context"
1184
+ # Get channels index (skipping zero channel contexts)
1185
+ channels_id = self.channels_ids[layer]
1186
+ # Get channels
1187
+ channels = channels_list[channels_id]
1188
+ message = f"Missing context for layer {layer} at index {channels_id}"
1189
+ assert exists(channels), message
1190
+ # Check channels
1191
+ num_channels = self.context_channels[layer]
1192
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
+ assert channels.shape[1] == num_channels, message
1194
+ # STFT channels if requested
1195
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
+ return channels
1197
+
1198
+ def get_mapping(
1199
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
+ ) -> Optional[Tensor]:
1201
+ """Combines context time features and features into mapping"""
1202
+ items, mapping = [], None
1203
+ # Compute time features
1204
+ if self.use_context_time:
1205
+ assert_message = "use_context_time=True but no time features provided"
1206
+ assert exists(time), assert_message
1207
+ items += [self.to_time(time)]
1208
+ # Compute features
1209
+ if self.use_context_features:
1210
+ assert_message = "context_features exists but no features provided"
1211
+ assert exists(features), assert_message
1212
+ items += [self.to_features(features)]
1213
+ # Compute joint mapping
1214
+ if self.use_context_time or self.use_context_features:
1215
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
+ mapping = self.to_mapping(mapping)
1217
+ return mapping
1218
+
1219
+ def forward(
1220
+ self,
1221
+ x: Tensor,
1222
+ time: Optional[Tensor] = None,
1223
+ *,
1224
+ features: Optional[Tensor] = None,
1225
+ channels_list: Optional[Sequence[Tensor]] = None,
1226
+ embedding: Optional[Tensor] = None,
1227
+ embedding_mask: Optional[Tensor] = None,
1228
+ causal: Optional[bool] = False,
1229
+ ) -> Tensor:
1230
+ channels = self.get_channels(channels_list, layer=0)
1231
+ # Apply stft if required
1232
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1233
+ # Concat context channels at layer 0 if provided
1234
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1235
+ # Compute mapping from time and features
1236
+ mapping = self.get_mapping(time, features)
1237
+ x = self.to_in(x, mapping, causal=causal)
1238
+ skips_list = [x]
1239
+
1240
+ for i, downsample in enumerate(self.downsamples):
1241
+ channels = self.get_channels(channels_list, layer=i + 1)
1242
+ x, skips = downsample(
1243
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1244
+ )
1245
+ skips_list += [skips]
1246
+
1247
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1248
+
1249
+ for i, upsample in enumerate(self.upsamples):
1250
+ skips = skips_list.pop()
1251
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
+
1253
+ x += skips_list.pop()
1254
+ x = self.to_out(x, mapping, causal=causal)
1255
+ x = self.stft.decode1d(x) if self.use_stft else x
1256
+
1257
+ return x
1258
+
1259
+
1260
+ """ Conditioning Modules """
1261
+
1262
+
1263
+ class FixedEmbedding(nn.Module):
1264
+ def __init__(self, max_length: int, features: int):
1265
+ super().__init__()
1266
+ self.max_length = max_length
1267
+ self.embedding = nn.Embedding(max_length, features)
1268
+
1269
+ def forward(self, x: Tensor) -> Tensor:
1270
+ batch_size, length, device = *x.shape[0:2], x.device
1271
+ assert_message = "Input sequence length must be <= max_length"
1272
+ assert length <= self.max_length, assert_message
1273
+ position = torch.arange(length, device=device)
1274
+ fixed_embedding = self.embedding(position)
1275
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1276
+ return fixed_embedding
1277
+
1278
+
1279
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1280
+ if proba == 1:
1281
+ return torch.ones(shape, device=device, dtype=torch.bool)
1282
+ elif proba == 0:
1283
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1284
+ else:
1285
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1286
+
1287
+
1288
+ class UNetCFG1d(UNet1d):
1289
+
1290
+ """UNet1d with Classifier-Free Guidance"""
1291
+
1292
+ def __init__(
1293
+ self,
1294
+ context_embedding_max_length: int,
1295
+ context_embedding_features: int,
1296
+ use_xattn_time: bool = False,
1297
+ **kwargs,
1298
+ ):
1299
+ super().__init__(
1300
+ context_embedding_features=context_embedding_features, **kwargs
1301
+ )
1302
+
1303
+ self.use_xattn_time = use_xattn_time
1304
+
1305
+ if use_xattn_time:
1306
+ assert exists(context_embedding_features)
1307
+ self.to_time_embedding = nn.Sequential(
1308
+ TimePositionalEmbedding(
1309
+ dim=kwargs["channels"], out_features=context_embedding_features
1310
+ ),
1311
+ nn.GELU(),
1312
+ )
1313
+
1314
+ context_embedding_max_length += 1 # Add one for time embedding
1315
+
1316
+ self.fixed_embedding = FixedEmbedding(
1317
+ max_length=context_embedding_max_length, features=context_embedding_features
1318
+ )
1319
+
1320
+ def forward( # type: ignore
1321
+ self,
1322
+ x: Tensor,
1323
+ time: Tensor,
1324
+ *,
1325
+ embedding: Tensor,
1326
+ embedding_mask: Optional[Tensor] = None,
1327
+ embedding_scale: float = 1.0,
1328
+ embedding_mask_proba: float = 0.0,
1329
+ batch_cfg: bool = False,
1330
+ rescale_cfg: bool = False,
1331
+ scale_phi: float = 0.4,
1332
+ negative_embedding: Optional[Tensor] = None,
1333
+ negative_embedding_mask: Optional[Tensor] = None,
1334
+ **kwargs,
1335
+ ) -> Tensor:
1336
+ b, device = embedding.shape[0], embedding.device
1337
+
1338
+ if self.use_xattn_time:
1339
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1340
+
1341
+ if embedding_mask is not None:
1342
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1343
+
1344
+ fixed_embedding = self.fixed_embedding(embedding)
1345
+
1346
+ if embedding_mask_proba > 0.0:
1347
+ # Randomly mask embedding
1348
+ batch_mask = rand_bool(
1349
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1350
+ )
1351
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1352
+
1353
+ if embedding_scale != 1.0:
1354
+ if batch_cfg:
1355
+ batch_x = torch.cat([x, x], dim=0)
1356
+ batch_time = torch.cat([time, time], dim=0)
1357
+
1358
+ if negative_embedding is not None:
1359
+ if negative_embedding_mask is not None:
1360
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1361
+
1362
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1363
+
1364
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1365
+
1366
+ else:
1367
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1368
+
1369
+ batch_mask = None
1370
+ if embedding_mask is not None:
1371
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1372
+
1373
+ batch_features = None
1374
+ features = kwargs.pop("features", None)
1375
+ if self.use_context_features:
1376
+ batch_features = torch.cat([features, features], dim=0)
1377
+
1378
+ batch_channels = None
1379
+ channels_list = kwargs.pop("channels_list", None)
1380
+ if self.use_context_channels:
1381
+ batch_channels = []
1382
+ for channels in channels_list:
1383
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1384
+
1385
+ # Compute both normal and fixed embedding outputs
1386
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1387
+ out, out_masked = batch_out.chunk(2, dim=0)
1388
+
1389
+ else:
1390
+ # Compute both normal and fixed embedding outputs
1391
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1392
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1393
+
1394
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1395
+
1396
+ if rescale_cfg:
1397
+
1398
+ out_std = out.std(dim=1, keepdim=True)
1399
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1400
+
1401
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1402
+
1403
+ else:
1404
+
1405
+ return out_cfg
1406
+
1407
+ else:
1408
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1409
+
1410
+
1411
+ class UNetNCCA1d(UNet1d):
1412
+
1413
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1414
+
1415
+ def __init__(self, context_features: int, **kwargs):
1416
+ super().__init__(context_features=context_features, **kwargs)
1417
+ self.embedder = NumberEmbedder(features=context_features)
1418
+
1419
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1420
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1421
+ return x.expand(shape)
1422
+
1423
+ def forward( # type: ignore
1424
+ self,
1425
+ x: Tensor,
1426
+ time: Tensor,
1427
+ *,
1428
+ channels_list: Sequence[Tensor],
1429
+ channels_augmentation: Union[
1430
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1431
+ ] = False,
1432
+ channels_scale: Union[
1433
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1434
+ ] = 0,
1435
+ **kwargs,
1436
+ ) -> Tensor:
1437
+ b, n = x.shape[0], len(channels_list)
1438
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1439
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1440
+
1441
+ # Augmentation (for each channel list item)
1442
+ for i in range(n):
1443
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1444
+ scale = rearrange(scale, "b -> b 1 1")
1445
+ item = channels_list[i]
1446
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1447
+
1448
+ # Scale embedding (sum reduction if more than one channel list item)
1449
+ channels_scale_emb = self.embedder(channels_scale)
1450
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1451
+
1452
+ return super().forward(
1453
+ x=x,
1454
+ time=time,
1455
+ channels_list=channels_list,
1456
+ features=channels_scale_emb,
1457
+ **kwargs,
1458
+ )
1459
+
1460
+
1461
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1462
+ def __init__(self, *args, **kwargs):
1463
+ super().__init__(*args, **kwargs)
1464
+
1465
+ def forward(self, *args, **kwargs): # type: ignore
1466
+ return UNetCFG1d.forward(self, *args, **kwargs)
1467
+
1468
+
1469
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1470
+ if type == "base":
1471
+ return UNet1d(**kwargs)
1472
+ elif type == "all":
1473
+ return UNetAll1d(**kwargs)
1474
+ elif type == "cfg":
1475
+ return UNetCFG1d(**kwargs)
1476
+ elif type == "ncca":
1477
+ return UNetNCCA1d(**kwargs)
1478
+ else:
1479
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1480
+
1481
+ class NumberEmbedder(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ features: int,
1485
+ dim: int = 256,
1486
+ ):
1487
+ super().__init__()
1488
+ self.features = features
1489
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1490
+
1491
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1492
+ if not torch.is_tensor(x):
1493
+ device = next(self.embedding.parameters()).device
1494
+ x = torch.tensor(x, device=device)
1495
+ assert isinstance(x, Tensor)
1496
+ shape = x.shape
1497
+ x = rearrange(x, "... -> (...)")
1498
+ embedding = self.embedding(x)
1499
+ x = embedding.view(*shape, self.features)
1500
+ return x # type: ignore
1501
+
1502
+
1503
+ """
1504
+ Audio Transforms
1505
+ """
1506
+
1507
+
1508
+ class STFT(nn.Module):
1509
+ """Helper for torch stft and istft"""
1510
+
1511
+ def __init__(
1512
+ self,
1513
+ num_fft: int = 1023,
1514
+ hop_length: int = 256,
1515
+ window_length: Optional[int] = None,
1516
+ length: Optional[int] = None,
1517
+ use_complex: bool = False,
1518
+ ):
1519
+ super().__init__()
1520
+ self.num_fft = num_fft
1521
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1522
+ self.window_length = default(window_length, num_fft)
1523
+ self.length = length
1524
+ self.register_buffer("window", torch.hann_window(self.window_length))
1525
+ self.use_complex = use_complex
1526
+
1527
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1528
+ b = wave.shape[0]
1529
+ wave = rearrange(wave, "b c t -> (b c) t")
1530
+
1531
+ stft = torch.stft(
1532
+ wave,
1533
+ n_fft=self.num_fft,
1534
+ hop_length=self.hop_length,
1535
+ win_length=self.window_length,
1536
+ window=self.window, # type: ignore
1537
+ return_complex=True,
1538
+ normalized=True,
1539
+ )
1540
+
1541
+ if self.use_complex:
1542
+ # Returns real and imaginary
1543
+ stft_a, stft_b = stft.real, stft.imag
1544
+ else:
1545
+ # Returns magnitude and phase matrices
1546
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1547
+ stft_a, stft_b = magnitude, phase
1548
+
1549
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1550
+
1551
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1552
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1553
+ length = closest_power_2(l * self.hop_length)
1554
+
1555
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1556
+
1557
+ if self.use_complex:
1558
+ real, imag = stft_a, stft_b
1559
+ else:
1560
+ magnitude, phase = stft_a, stft_b
1561
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1562
+
1563
+ stft = torch.stack([real, imag], dim=-1)
1564
+
1565
+ wave = torch.istft(
1566
+ stft,
1567
+ n_fft=self.num_fft,
1568
+ hop_length=self.hop_length,
1569
+ win_length=self.window_length,
1570
+ window=self.window, # type: ignore
1571
+ length=default(self.length, length),
1572
+ normalized=True,
1573
+ )
1574
+
1575
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1576
+
1577
+ def encode1d(
1578
+ self, wave: Tensor, stacked: bool = True
1579
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1580
+ stft_a, stft_b = self.encode(wave)
1581
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1582
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1583
+
1584
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1585
+ f = self.num_fft // 2 + 1
1586
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1587
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1588
+ return self.decode(stft_a, stft_b)
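
An aside, not part of the commit: the classifier-free guidance applied in UNetCFG1d.forward above reduces to a small amount of arithmetic on the conditional and fixed-embedding (unconditional) outputs. The sketch below restates it as a standalone helper; `apply_cfg`, `cond_out` and `uncond_out` are illustrative names that do not exist in this repository.

```python
import torch

def apply_cfg(cond_out: torch.Tensor,
              uncond_out: torch.Tensor,
              embedding_scale: float,
              rescale_cfg: bool = False,
              scale_phi: float = 0.4) -> torch.Tensor:
    # Standard CFG: move from the unconditional output toward the conditional one
    out_cfg = uncond_out + (cond_out - uncond_out) * embedding_scale
    if rescale_cfg:
        # Match the guided output's per-sample std to the conditional output,
        # then blend the rescaled and raw CFG results with scale_phi
        out_std = cond_out.std(dim=1, keepdim=True)
        out_cfg_std = out_cfg.std(dim=1, keepdim=True)
        return scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg
    return out_cfg
```

With embedding_scale = 1.0 the helper returns the conditional output unchanged, which is why the forward pass above skips the CFG branch entirely in that case.
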
stable_audio_tools/models/autoencoders.py ADDED
@@ -0,0 +1,800 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn, sin, pow
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import List, Literal, Dict, Any, Callable
11
+ from einops import rearrange
12
+
13
+ from ..inference.sampling import sample
14
+ from ..inference.utils import prepare_audio
15
+ from .blocks import SnakeBeta
16
+ from .bottleneck import Bottleneck, DiscreteBottleneck
17
+ from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
18
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
19
+ from .pretransforms import Pretransform, AutoencoderPretransform
20
+
21
+ def checkpoint(function, *args, **kwargs):
22
+ kwargs.setdefault("use_reentrant", False)
23
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
24
+
25
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
26
+ if activation == "elu":
27
+ act = nn.ELU()
28
+ elif activation == "snake":
29
+ act = SnakeBeta(channels)
30
+ elif activation == "none":
31
+ act = nn.Identity()
32
+ else:
33
+ raise ValueError(f"Unknown activation {activation}")
34
+
35
+ if antialias:
36
+ act = Activation1d(act)
37
+
38
+ return act
39
+
40
+ class ResidualUnit(nn.Module):
41
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
42
+ super().__init__()
43
+
44
+ self.dilation = dilation
45
+
46
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels)
47
+
48
+ padding = (dilation * (7-1)) // 2
49
+
50
+ self.layers = nn.Sequential(
51
+ act,
52
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
53
+ kernel_size=7, dilation=dilation, padding=padding),
54
+ act,
55
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
56
+ kernel_size=1)
57
+ )
58
+
59
+ def forward(self, x):
60
+ res = x
61
+
62
+ # Disable checkpoint until tensor mismatch is fixed
63
+ #x = checkpoint(self.layers, x)
64
+ x = self.layers(x)
65
+
66
+ return x + res
67
+ class EncoderBlock(nn.Module):
68
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
69
+ super().__init__()
70
+
71
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
72
+
73
+ self.layers = nn.Sequential(
74
+ ResidualUnit(in_channels=in_channels,
75
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
76
+ ResidualUnit(in_channels=in_channels,
77
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
78
+ ResidualUnit(in_channels=in_channels,
79
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
80
+ act,
81
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
82
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
83
+ )
84
+
85
+ def forward(self, x):
86
+ return self.layers(x)
87
+
88
+ class DecoderBlock(nn.Module):
89
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
90
+ super().__init__()
91
+
92
+ if use_nearest_upsample:
93
+ upsample_layer = nn.Sequential(
94
+ nn.Upsample(scale_factor=stride, mode="nearest"),
95
+ WNConv1d(in_channels=in_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=2*stride,
98
+ stride=1,
99
+ bias=False,
100
+ padding='same')
101
+ )
102
+ else:
103
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
106
+
107
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
108
+
109
+ self.layers = nn.Sequential(
110
+ act,
111
+ upsample_layer,
112
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
113
+ dilation=1, use_snake=use_snake),
114
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
115
+ dilation=3, use_snake=use_snake),
116
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
117
+ dilation=9, use_snake=use_snake),
118
+ )
119
+
120
+ def forward(self, x):
121
+ return self.layers(x)
122
+
123
+ class OobleckEncoder(nn.Module):
124
+ def __init__(self,
125
+ in_channels=2,
126
+ channels=128,
127
+ latent_dim=32,
128
+ c_mults = [1, 2, 4, 8],
129
+ strides = [2, 4, 8, 8],
130
+ use_snake=False,
131
+ antialias_activation=False
132
+ ):
133
+ super().__init__()
134
+
135
+ c_mults = [1] + c_mults
136
+
137
+ self.depth = len(c_mults)
138
+
139
+ layers = [
140
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
141
+ ]
142
+
143
+ for i in range(self.depth-1):
144
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
145
+
146
+ layers += [
147
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
148
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
149
+ ]
150
+
151
+ self.layers = nn.Sequential(*layers)
152
+
153
+ def forward(self, x):
154
+ return self.layers(x)
155
+
156
+
157
+ class OobleckDecoder(nn.Module):
158
+ def __init__(self,
159
+ out_channels=2,
160
+ channels=128,
161
+ latent_dim=32,
162
+ c_mults = [1, 2, 4, 8],
163
+ strides = [2, 4, 8, 8],
164
+ use_snake=False,
165
+ antialias_activation=False,
166
+ use_nearest_upsample=False,
167
+ final_tanh=True):
168
+ super().__init__()
169
+
170
+ c_mults = [1] + c_mults
171
+
172
+ self.depth = len(c_mults)
173
+
174
+ layers = [
175
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
176
+ ]
177
+
178
+ for i in range(self.depth-1, 0, -1):
179
+ layers += [DecoderBlock(
180
+ in_channels=c_mults[i]*channels,
181
+ out_channels=c_mults[i-1]*channels,
182
+ stride=strides[i-1],
183
+ use_snake=use_snake,
184
+ antialias_activation=antialias_activation,
185
+ use_nearest_upsample=use_nearest_upsample
186
+ )
187
+ ]
188
+
189
+ layers += [
190
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
191
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
192
+ nn.Tanh() if final_tanh else nn.Identity()
193
+ ]
194
+
195
+ self.layers = nn.Sequential(*layers)
196
+
197
+ def forward(self, x):
198
+ return self.layers(x)
199
+
200
+ class DACEncoderWrapper(nn.Module):
201
+ def __init__(self, in_channels=1, **kwargs):
202
+ super().__init__()
203
+
204
+ from dac.model.dac import Encoder as DACEncoder
205
+
206
+ latent_dim = kwargs.pop("latent_dim", None)
207
+
208
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
209
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
210
+ self.latent_dim = latent_dim
211
+
212
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
213
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
214
+
215
+ if in_channels != 1:
216
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
217
+
218
+ def forward(self, x):
219
+ x = self.encoder(x)
220
+ x = self.proj_out(x)
221
+ return x
222
+
223
+ class DACDecoderWrapper(nn.Module):
224
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
225
+ super().__init__()
226
+
227
+ from dac.model.dac import Decoder as DACDecoder
228
+
229
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
230
+
231
+ self.latent_dim = latent_dim
232
+
233
+ def forward(self, x):
234
+ return self.decoder(x)
235
+
236
+ class AudioAutoencoder(nn.Module):
237
+ def __init__(
238
+ self,
239
+ encoder,
240
+ decoder,
241
+ latent_dim,
242
+ downsampling_ratio,
243
+ sample_rate,
244
+ io_channels=2,
245
+ bottleneck: Bottleneck = None,
246
+ pretransform: Pretransform = None,
247
+ in_channels = None,
248
+ out_channels = None,
249
+ soft_clip = False
250
+ ):
251
+ super().__init__()
252
+
253
+ self.downsampling_ratio = downsampling_ratio
254
+ self.sample_rate = sample_rate
255
+
256
+ self.latent_dim = latent_dim
257
+ self.io_channels = io_channels
258
+ self.in_channels = io_channels
259
+ self.out_channels = io_channels
260
+
261
+ self.min_length = self.downsampling_ratio
262
+
263
+ if in_channels is not None:
264
+ self.in_channels = in_channels
265
+
266
+ if out_channels is not None:
267
+ self.out_channels = out_channels
268
+
269
+ self.bottleneck = bottleneck
270
+
271
+ self.encoder = encoder
272
+
273
+ self.decoder = decoder
274
+
275
+ self.pretransform = pretransform
276
+
277
+ self.soft_clip = soft_clip
278
+
279
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
280
+
281
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
282
+
283
+ info = {}
284
+
285
+ if self.pretransform is not None and not skip_pretransform:
286
+ if self.pretransform.enable_grad:
287
+ if iterate_batch:
288
+ audios = []
289
+ for i in range(audio.shape[0]):
290
+ audios.append(self.pretransform.encode(audio[i:i+1]))
291
+ audio = torch.cat(audios, dim=0)
292
+ else:
293
+ audio = self.pretransform.encode(audio)
294
+ else:
295
+ with torch.no_grad():
296
+ if iterate_batch:
297
+ audios = []
298
+ for i in range(audio.shape[0]):
299
+ audios.append(self.pretransform.encode(audio[i:i+1]))
300
+ audio = torch.cat(audios, dim=0)
301
+ else:
302
+ audio = self.pretransform.encode(audio)
303
+
304
+ if self.encoder is not None:
305
+ if iterate_batch:
306
+ latents = []
307
+ for i in range(audio.shape[0]):
308
+ latents.append(self.encoder(audio[i:i+1]))
309
+ latents = torch.cat(latents, dim=0)
310
+ else:
311
+ latents = self.encoder(audio)
312
+ else:
313
+ latents = audio
314
+
315
+ if self.bottleneck is not None:
316
+ # TODO: Add iterate batch logic, needs to merge the info dicts
317
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
318
+
319
+ info.update(bottleneck_info)
320
+
321
+ if return_info:
322
+ return latents, info
323
+
324
+ return latents
325
+
326
+ def decode(self, latents, iterate_batch=False, **kwargs):
327
+
328
+ if self.bottleneck is not None:
329
+ if iterate_batch:
330
+ decoded = []
331
+ for i in range(latents.shape[0]):
332
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
333
+ latents = torch.cat(decoded, dim=0)
334
+ else:
335
+ latents = self.bottleneck.decode(latents)
336
+
337
+ if iterate_batch:
338
+ decoded = []
339
+ for i in range(latents.shape[0]):
340
+ decoded.append(self.decoder(latents[i:i+1]))
341
+ decoded = torch.cat(decoded, dim=0)
342
+ else:
343
+ decoded = self.decoder(latents, **kwargs)
344
+
345
+ if self.pretransform is not None:
346
+ if self.pretransform.enable_grad:
347
+ if iterate_batch:
348
+ decodeds = []
349
+ for i in range(decoded.shape[0]):
350
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
351
+ decoded = torch.cat(decodeds, dim=0)
352
+ else:
353
+ decoded = self.pretransform.decode(decoded)
354
+ else:
355
+ with torch.no_grad():
356
+ if iterate_batch:
357
+ decodeds = []
358
+ for i in range(latents.shape[0]):
359
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
360
+ decoded = torch.cat(decodeds, dim=0)
361
+ else:
362
+ decoded = self.pretransform.decode(decoded)
363
+
364
+ if self.soft_clip:
365
+ decoded = torch.tanh(decoded)
366
+
367
+ return decoded
368
+
369
+ def decode_tokens(self, tokens, **kwargs):
370
+ '''
371
+ Decode discrete tokens to audio
372
+ Only works with discrete autoencoders
373
+ '''
374
+
375
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
376
+
377
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
378
+
379
+ return self.decode(latents, **kwargs)
380
+
381
+
382
+ def preprocess_audio_for_encoder(self, audio, in_sr):
383
+ '''
384
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
385
+ If the model is mono, stereo audio will be converted to mono.
386
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
387
+ Audio will be resampled to the model's sample rate.
388
+ The output will have batch size 1 and be shape (1 x Channels x Length)
389
+ '''
390
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
391
+
392
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
393
+ '''
394
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
395
+ The audio in that list can be of different lengths and channels.
396
+ in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
397
+ All audio will be resampled to the model's sample rate.
398
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
399
+ If the model is mono, all audio will be converted to mono.
400
+ The output will be a tensor of shape (Batch x Channels x Length)
401
+ '''
402
+ batch_size = len(audio_list)
403
+ if isinstance(in_sr_list, int):
404
+ in_sr_list = [in_sr_list]*batch_size
405
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
406
+ new_audio = []
407
+ max_length = 0
408
+ # resample & find the max length
409
+ for i in range(batch_size):
410
+ audio = audio_list[i]
411
+ in_sr = in_sr_list[i]
412
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
413
+ # batchsize 1 was given by accident. Just squeeze it.
414
+ audio = audio.squeeze(0)
415
+ elif len(audio.shape) == 1:
416
+ # Mono signal, channel dimension is missing, unsqueeze it in
417
+ audio = audio.unsqueeze(0)
418
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
419
+ # Resample audio
420
+ if in_sr != self.sample_rate:
421
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
422
+ audio = resample_tf(audio)
423
+ new_audio.append(audio)
424
+ if audio.shape[-1] > max_length:
425
+ max_length = audio.shape[-1]
426
+ # Pad every audio to the same length, multiple of model's downsampling ratio
427
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
428
+ for i in range(batch_size):
429
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
430
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
431
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
432
+ # convert to tensor
433
+ return torch.stack(new_audio)
434
+
435
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
436
+ '''
437
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
438
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
439
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
440
+ and therefore you likely could use the same values with decode_audio.
441
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
442
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
443
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
444
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
445
+ Smaller chunk_size uses less memory, but more compute.
446
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
447
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
448
+ For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
449
+ if not chunked:
450
+ # default behavior. Encode the entire audio in parallel
451
+ return self.encode(audio, **kwargs)
452
+ else:
453
+ # CHUNKED ENCODING
454
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
455
+ samples_per_latent = self.downsampling_ratio
456
+ total_size = audio.shape[2] # in samples
457
+ batch_size = audio.shape[0]
458
+ chunk_size *= samples_per_latent # converting metric in latents to samples
459
+ overlap *= samples_per_latent # converting metric in latents to samples
460
+ hop_size = chunk_size - overlap
461
+ chunks = []
462
+ for i in range(0, total_size - chunk_size + 1, hop_size):
463
+ chunk = audio[:,:,i:i+chunk_size]
464
+ chunks.append(chunk)
465
+ if i+chunk_size != total_size:
466
+ # Final chunk
467
+ chunk = audio[:,:,-chunk_size:]
468
+ chunks.append(chunk)
469
+ chunks = torch.stack(chunks)
470
+ num_chunks = chunks.shape[0]
471
+ # Note: y_size might be a different value from the latent length used in diffusion training
472
+ # because we can encode audio of varying lengths
473
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
474
+ y_size = total_size // samples_per_latent
475
+ # Create an empty latent, we will populate it with chunks as we encode them
476
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
477
+ for i in range(num_chunks):
478
+ x_chunk = chunks[i,:]
479
+ # encode the chunk
480
+ y_chunk = self.encode(x_chunk)
481
+ # figure out where to put the audio along the time domain
482
+ if i == num_chunks-1:
483
+ # final chunk always goes at the end
484
+ t_end = y_size
485
+ t_start = t_end - y_chunk.shape[2]
486
+ else:
487
+ t_start = i * hop_size // samples_per_latent
488
+ t_end = t_start + chunk_size // samples_per_latent
489
+ # remove the edges of the overlaps
490
+ ol = overlap//samples_per_latent//2
491
+ chunk_start = 0
492
+ chunk_end = y_chunk.shape[2]
493
+ if i > 0:
494
+ # no overlap for the start of the first chunk
495
+ t_start += ol
496
+ chunk_start += ol
497
+ if i < num_chunks-1:
498
+ # no overlap for the end of the last chunk
499
+ t_end -= ol
500
+ chunk_end -= ol
501
+ # paste the chunked audio into our y_final output audio
502
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
503
+ return y_final
504
+
505
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
506
+ '''
507
+ Decode latents to audio.
508
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
509
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
510
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
511
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
512
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
513
+ Smaller chunk_size uses less memory, but more compute.
514
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
515
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
516
+ For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
517
+ if not chunked:
518
+ # default behavior. Decode the entire latent in parallel
519
+ return self.decode(latents, **kwargs)
520
+ else:
521
+ # chunked decoding
522
+ hop_size = chunk_size - overlap
523
+ total_size = latents.shape[2]
524
+ batch_size = latents.shape[0]
525
+ chunks = []
526
+ for i in range(0, total_size - chunk_size + 1, hop_size):
527
+ chunk = latents[:,:,i:i+chunk_size]
528
+ chunks.append(chunk)
529
+ if i+chunk_size != total_size:
530
+ # Final chunk
531
+ chunk = latents[:,:,-chunk_size:]
532
+ chunks.append(chunk)
533
+ chunks = torch.stack(chunks)
534
+ num_chunks = chunks.shape[0]
535
+ # samples_per_latent is just the downsampling ratio
536
+ samples_per_latent = self.downsampling_ratio
537
+ # Create an empty waveform, we will populate it with chunks as we decode them
538
+ y_size = total_size * samples_per_latent
539
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
540
+ for i in range(num_chunks):
541
+ x_chunk = chunks[i,:]
542
+ # decode the chunk
543
+ y_chunk = self.decode(x_chunk)
544
+ # figure out where to put the audio along the time domain
545
+ if i == num_chunks-1:
546
+ # final chunk always goes at the end
547
+ t_end = y_size
548
+ t_start = t_end - y_chunk.shape[2]
549
+ else:
550
+ t_start = i * hop_size * samples_per_latent
551
+ t_end = t_start + chunk_size * samples_per_latent
552
+ # remove the edges of the overlaps
553
+ ol = (overlap//2) * samples_per_latent
554
+ chunk_start = 0
555
+ chunk_end = y_chunk.shape[2]
556
+ if i > 0:
557
+ # no overlap for the start of the first chunk
558
+ t_start += ol
559
+ chunk_start += ol
560
+ if i < num_chunks-1:
561
+ # no overlap for the end of the last chunk
562
+ t_end -= ol
563
+ chunk_end -= ol
564
+ # paste the chunked audio into our y_final output audio
565
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
566
+ return y_final
567
+
568
+
569
+ class DiffusionAutoencoder(AudioAutoencoder):
570
+ def __init__(
571
+ self,
572
+ diffusion: ConditionedDiffusionModel,
573
+ diffusion_downsampling_ratio,
574
+ *args,
575
+ **kwargs
576
+ ):
577
+ super().__init__(*args, **kwargs)
578
+
579
+ self.diffusion = diffusion
580
+
581
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
582
+
583
+ if self.encoder is not None:
584
+ # Shrink the initial encoder parameters to avoid saturated latents
585
+ with torch.no_grad():
586
+ for param in self.encoder.parameters():
587
+ param *= 0.5
588
+
589
+ def decode(self, latents, steps=100):
590
+
591
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
592
+
593
+ if self.bottleneck is not None:
594
+ latents = self.bottleneck.decode(latents)
595
+
596
+ if self.decoder is not None:
597
+ latents = self.decoder(latents)
598
+
599
+ # Upsample latents to match diffusion length
600
+ if latents.shape[2] != upsampled_length:
601
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
602
+
603
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
604
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
605
+
606
+ if self.pretransform is not None:
607
+ if self.pretransform.enable_grad:
608
+ decoded = self.pretransform.decode(decoded)
609
+ else:
610
+ with torch.no_grad():
611
+ decoded = self.pretransform.decode(decoded)
612
+
613
+ return decoded
614
+
615
+ # AE factories
616
+
617
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
618
+ encoder_type = encoder_config.get("type", None)
619
+ assert encoder_type is not None, "Encoder type must be specified"
620
+
621
+ if encoder_type == "oobleck":
622
+ encoder = OobleckEncoder(
623
+ **encoder_config["config"]
624
+ )
625
+
626
+ elif encoder_type == "seanet":
627
+ from encodec.modules import SEANetEncoder
628
+ seanet_encoder_config = encoder_config["config"]
629
+
630
+ #SEANet encoder expects strides in reverse order
631
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
632
+ encoder = SEANetEncoder(
633
+ **seanet_encoder_config
634
+ )
635
+ elif encoder_type == "dac":
636
+ dac_config = encoder_config["config"]
637
+
638
+ encoder = DACEncoderWrapper(**dac_config)
639
+ elif encoder_type == "local_attn":
640
+ from .local_attention import TransformerEncoder1D
641
+
642
+ local_attn_config = encoder_config["config"]
643
+
644
+ encoder = TransformerEncoder1D(
645
+ **local_attn_config
646
+ )
647
+ else:
648
+ raise ValueError(f"Unknown encoder type {encoder_type}")
649
+
650
+ requires_grad = encoder_config.get("requires_grad", True)
651
+ if not requires_grad:
652
+ for param in encoder.parameters():
653
+ param.requires_grad = False
654
+
655
+ return encoder
656
+
657
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
658
+ decoder_type = decoder_config.get("type", None)
659
+ assert decoder_type is not None, "Decoder type must be specified"
660
+
661
+ if decoder_type == "oobleck":
662
+ decoder = OobleckDecoder(
663
+ **decoder_config["config"]
664
+ )
665
+ elif decoder_type == "seanet":
666
+ from encodec.modules import SEANetDecoder
667
+
668
+ decoder = SEANetDecoder(
669
+ **decoder_config["config"]
670
+ )
671
+ elif decoder_type == "dac":
672
+ dac_config = decoder_config["config"]
673
+
674
+ decoder = DACDecoderWrapper(**dac_config)
675
+ elif decoder_type == "local_attn":
676
+ from .local_attention import TransformerDecoder1D
677
+
678
+ local_attn_config = decoder_config["config"]
679
+
680
+ decoder = TransformerDecoder1D(
681
+ **local_attn_config
682
+ )
683
+ else:
684
+ raise ValueError(f"Unknown decoder type {decoder_type}")
685
+
686
+ requires_grad = decoder_config.get("requires_grad", True)
687
+ if not requires_grad:
688
+ for param in decoder.parameters():
689
+ param.requires_grad = False
690
+
691
+ return decoder
692
+
693
+ def create_autoencoder_from_config(config: Dict[str, Any]):
694
+
695
+ ae_config = config["model"]
696
+
697
+ encoder = create_encoder_from_config(ae_config["encoder"])
698
+ decoder = create_decoder_from_config(ae_config["decoder"])
699
+
700
+ bottleneck = ae_config.get("bottleneck", None)
701
+
702
+ latent_dim = ae_config.get("latent_dim", None)
703
+ assert latent_dim is not None, "latent_dim must be specified in model config"
704
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
705
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
706
+ io_channels = ae_config.get("io_channels", None)
707
+ assert io_channels is not None, "io_channels must be specified in model config"
708
+ sample_rate = config.get("sample_rate", None)
709
+ assert sample_rate is not None, "sample_rate must be specified in model config"
710
+
711
+ in_channels = ae_config.get("in_channels", None)
712
+ out_channels = ae_config.get("out_channels", None)
713
+
714
+ pretransform = ae_config.get("pretransform", None)
715
+
716
+ if pretransform is not None:
717
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
718
+
719
+ if bottleneck is not None:
720
+ bottleneck = create_bottleneck_from_config(bottleneck)
721
+
722
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
723
+
724
+ return AudioAutoencoder(
725
+ encoder,
726
+ decoder,
727
+ io_channels=io_channels,
728
+ latent_dim=latent_dim,
729
+ downsampling_ratio=downsampling_ratio,
730
+ sample_rate=sample_rate,
731
+ bottleneck=bottleneck,
732
+ pretransform=pretransform,
733
+ in_channels=in_channels,
734
+ out_channels=out_channels,
735
+ soft_clip=soft_clip
736
+ )
737
+
738
+ def create_diffAE_from_config(config: Dict[str, Any]):
739
+
740
+ diffae_config = config["model"]
741
+
742
+ if "encoder" in diffae_config:
743
+ encoder = create_encoder_from_config(diffae_config["encoder"])
744
+ else:
745
+ encoder = None
746
+
747
+ if "decoder" in diffae_config:
748
+ decoder = create_decoder_from_config(diffae_config["decoder"])
749
+ else:
750
+ decoder = None
751
+
752
+ diffusion_model_type = diffae_config["diffusion"]["type"]
753
+
754
+ if diffusion_model_type == "DAU1d":
755
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
756
+ elif diffusion_model_type == "adp_1d":
757
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
758
+ elif diffusion_model_type == "dit":
759
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
760
+
761
+ latent_dim = diffae_config.get("latent_dim", None)
762
+ assert latent_dim is not None, "latent_dim must be specified in model config"
763
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
764
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
765
+ io_channels = diffae_config.get("io_channels", None)
766
+ assert io_channels is not None, "io_channels must be specified in model config"
767
+ sample_rate = config.get("sample_rate", None)
768
+ assert sample_rate is not None, "sample_rate must be specified in model config"
769
+
770
+ bottleneck = diffae_config.get("bottleneck", None)
771
+
772
+ pretransform = diffae_config.get("pretransform", None)
773
+
774
+ if pretransform is not None:
775
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
776
+
777
+ if bottleneck is not None:
778
+ bottleneck = create_bottleneck_from_config(bottleneck)
779
+
780
+ diffusion_downsampling_ratio = None
781
+
782
+ if diffusion_model_type == "DAU1d":
783
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
784
+ elif diffusion_model_type == "adp_1d":
785
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
786
+ elif diffusion_model_type == "dit":
787
+ diffusion_downsampling_ratio = 1
788
+
789
+ return DiffusionAutoencoder(
790
+ encoder=encoder,
791
+ decoder=decoder,
792
+ diffusion=diffusion,
793
+ io_channels=io_channels,
794
+ sample_rate=sample_rate,
795
+ latent_dim=latent_dim,
796
+ downsampling_ratio=downsampling_ratio,
797
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
798
+ bottleneck=bottleneck,
799
+ pretransform=pretransform
800
+ )
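
As a usage note (not part of the commit): the chunked paths documented in encode_audio and decode_audio above are driven with latent-domain overlap/chunk_size values. A minimal sketch, assuming an autoencoder model config JSON like the ones under model_configs/autoencoders and a placeholder input file:

```python
import json
import torchaudio
from stable_audio_tools.models.autoencoders import create_autoencoder_from_config

with open("autoencoder_config.json") as f:   # placeholder path to a model config
    model_config = json.load(f)

autoencoder = create_autoencoder_from_config(model_config)
audio, sr = torchaudio.load("input.wav")     # placeholder file, shape (channels, samples)

# Resample to the model rate, pad to a multiple of the downsampling ratio, add a batch dim
batch = autoencoder.preprocess_audio_for_encoder(audio, sr)

# overlap and chunk_size are measured in latents, per the docstrings above
latents = autoencoder.encode_audio(batch, chunked=True, overlap=32, chunk_size=128)
reconstruction = autoencoder.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
```

Because both parameters are expressed in the latent domain, the same overlap/chunk_size pair can usually be reused for encoding and decoding.
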
stable_audio_tools/models/blocks.py ADDED
@@ -0,0 +1,339 @@
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ self.use_flash = False
45
+
46
+ if not self.use_flash:
47
+ return
48
+
49
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
+
51
+ if device_properties.major == 8 and device_properties.minor == 0:
52
+ # Use flash attention for A100 GPUs
53
+ self.sdp_kernel_config = (True, False, False)
54
+ else:
55
+ # Don't use flash attention for other GPUs
56
+ self.sdp_kernel_config = (False, True, True)
57
+
58
+ def forward(self, input):
59
+ n, c, s = input.shape
60
+ qkv = self.qkv_proj(self.norm(input))
61
+ qkv = qkv.view(
62
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
+ q, k, v = qkv.chunk(3, dim=1)
64
+ scale = k.shape[3]**-0.25
65
+
66
+ if self.use_flash:
67
+ with sdp_kernel(*self.sdp_kernel_config):
68
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
+ else:
70
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
+
73
+
74
+ return input + self.dropout(self.out_proj(y))
75
+
76
+ class SkipBlock(nn.Module):
77
+ def __init__(self, *main):
78
+ super().__init__()
79
+ self.main = nn.Sequential(*main)
80
+
81
+ def forward(self, input):
82
+ return torch.cat([self.main(input), input], dim=1)
83
+
84
+ class FourierFeatures(nn.Module):
85
+ def __init__(self, in_features, out_features, std=1.):
86
+ super().__init__()
87
+ assert out_features % 2 == 0
88
+ self.weight = nn.Parameter(torch.randn(
89
+ [out_features // 2, in_features]) * std)
90
+
91
+ def forward(self, input):
92
+ f = 2 * math.pi * input @ self.weight.T
93
+ return torch.cat([f.cos(), f.sin()], dim=-1)
94
+
95
+ def expand_to_planes(input, shape):
96
+ return input[..., None].repeat([1, 1, shape[2]])
97
+
98
+ _kernels = {
99
+ 'linear':
100
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
+ 'cubic':
102
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
+ 'lanczos3':
105
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
+ }
110
+
111
+ class Downsample1d(nn.Module):
112
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
+ super().__init__()
114
+ self.pad_mode = pad_mode
115
+ kernel_1d = torch.tensor(_kernels[kernel])
116
+ self.pad = kernel_1d.shape[0] // 2 - 1
117
+ self.register_buffer('kernel', kernel_1d)
118
+ self.channels_last = channels_last
119
+
120
+ def forward(self, x):
121
+ if self.channels_last:
122
+ x = x.permute(0, 2, 1)
123
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
+ indices = torch.arange(x.shape[1], device=x.device)
126
+ weight[indices, indices] = self.kernel.to(weight)
127
+ x = F.conv1d(x, weight, stride=2)
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ return x
131
+
132
+
133
+ class Upsample1d(nn.Module):
134
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
+ super().__init__()
136
+ self.pad_mode = pad_mode
137
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
+ self.pad = kernel_1d.shape[0] // 2 - 1
139
+ self.register_buffer('kernel', kernel_1d)
140
+ self.channels_last = channels_last
141
+
142
+ def forward(self, x):
143
+ if self.channels_last:
144
+ x = x.permute(0, 2, 1)
145
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
+ indices = torch.arange(x.shape[1], device=x.device)
148
+ weight[indices, indices] = self.kernel.to(weight)
149
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ return x
153
+
154
+ def Downsample1d_2(
155
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
+ ) -> nn.Module:
157
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
+
159
+ return nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=factor * kernel_multiplier + 1,
163
+ stride=factor,
164
+ padding=factor * (kernel_multiplier // 2),
165
+ )
166
+
167
+
168
+ def Upsample1d_2(
169
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
+ ) -> nn.Module:
171
+
172
+ if factor == 1:
173
+ return nn.Conv1d(
174
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
+ )
176
+
177
+ if use_nearest:
178
+ return nn.Sequential(
179
+ nn.Upsample(scale_factor=factor, mode="nearest"),
180
+ nn.Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=3,
184
+ padding=1,
185
+ ),
186
+ )
187
+ else:
188
+ return nn.ConvTranspose1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=factor * 2,
192
+ stride=factor,
193
+ padding=factor // 2 + factor % 2,
194
+ output_padding=factor % 2,
195
+ )
196
+
197
+ def zero_init(layer):
198
+ nn.init.zeros_(layer.weight)
199
+ if layer.bias is not None:
200
+ nn.init.zeros_(layer.bias)
201
+ return layer
202
+
203
+ def rms_norm(x, scale, eps):
204
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
+ return x * scale.to(x.dtype)
208
+
209
+ rms_norm = torch.compile(rms_norm)
210
+
211
+ class AdaRMSNorm(nn.Module):
212
+ def __init__(self, features, cond_features, eps=1e-6):
213
+ super().__init__()
214
+ self.eps = eps
215
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
+
217
+ def extra_repr(self):
218
+ return f"eps={self.eps},"
219
+
220
+ def forward(self, x, cond):
221
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
+
223
+ def normalize(x, eps=1e-4):
224
+ dim = list(range(1, x.ndim))
225
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
+ alpha = np.sqrt(n.numel() / x.numel())
227
+ return x / torch.add(eps, n, alpha=alpha)
228
+
229
+ class ForcedWNConv1d(nn.Module):
230
+ def __init__(self, in_channels, out_channels, kernel_size=1):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
+
234
+ def forward(self, x):
235
+ if self.training:
236
+ with torch.no_grad():
237
+ self.weight.copy_(normalize(self.weight))
238
+
239
+ fan_in = self.weight[0].numel()
240
+
241
+ w = normalize(self.weight) / math.sqrt(fan_in)
242
+
243
+ return F.conv1d(x, w, padding='same')
244
+
245
+ # Kernels
246
+
247
+ use_compile = True
248
+
249
+ def compile(function, *args, **kwargs):
250
+ if not use_compile:
251
+ return function
252
+ try:
253
+ return torch.compile(function, *args, **kwargs)
254
+ except RuntimeError:
255
+ return function
256
+
257
+
258
+ @compile
259
+ def linear_geglu(x, weight, bias=None):
260
+ x = x @ weight.mT
261
+ if bias is not None:
262
+ x = x + bias
263
+ x, gate = x.chunk(2, dim=-1)
264
+ return x * F.gelu(gate)
265
+
266
+
267
+ @compile
268
+ def rms_norm(x, scale, eps):
269
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
+ return x * scale.to(x.dtype)
273
+
274
+ # Layers
275
+
276
+ class LinearGEGLU(nn.Linear):
277
+ def __init__(self, in_features, out_features, bias=True):
278
+ super().__init__(in_features, out_features * 2, bias=bias)
279
+ self.out_features = out_features
280
+
281
+ def forward(self, x):
282
+ return linear_geglu(x, self.weight, self.bias)
283
+
284
+
285
+ class RMSNorm(nn.Module):
286
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
287
+ super().__init__()
288
+ self.eps = eps
289
+
290
+ if fix_scale:
291
+ self.register_buffer("scale", torch.ones(shape))
292
+ else:
293
+ self.scale = nn.Parameter(torch.ones(shape))
294
+
295
+ def extra_repr(self):
296
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
+
298
+ def forward(self, x):
299
+ return rms_norm(x, self.scale, self.eps)
300
+
301
+ def snake_beta(x, alpha, beta):
302
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
+
304
+ try:
305
+ snake_beta = torch.compile(snake_beta)
306
+ except RuntimeError:
307
+ pass
308
+
309
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
+ # License available in LICENSES/LICENSE_NVIDIA.txt
311
+ class SnakeBeta(nn.Module):
312
+
313
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
+ super(SnakeBeta, self).__init__()
315
+ self.in_features = in_features
316
+
317
+ # initialize alpha
318
+ self.alpha_logscale = alpha_logscale
319
+ if self.alpha_logscale: # log scale alphas initialized to zeros
320
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
+ else: # linear scale alphas initialized to ones
323
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
+
326
+ self.alpha.requires_grad = alpha_trainable
327
+ self.beta.requires_grad = alpha_trainable
328
+
329
+ self.no_div_by_zero = 0.000000001
330
+
331
+ def forward(self, x):
332
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
+ if self.alpha_logscale:
335
+ alpha = torch.exp(alpha)
336
+ beta = torch.exp(beta)
337
+ x = snake_beta(x, alpha, beta)
338
+
339
+ return x
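
A quick sanity check, not part of the commit, for the SnakeBeta activation defined above: it applies x + (1/beta) * sin(alpha * x)^2 with one learnable (alpha, beta) pair per channel, so the input must be shaped (batch, channels, time) with channels equal to in_features.

```python
import torch
from stable_audio_tools.models.blocks import SnakeBeta

act = SnakeBeta(in_features=4)   # one (alpha, beta) pair per channel
x = torch.randn(2, 4, 16)        # (batch, channels, time)
y = act(x)
print(y.shape)                   # torch.Size([2, 4, 16])
```
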
stable_audio_tools/models/bottleneck.py ADDED
@@ -0,0 +1,326 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from einops import rearrange
6
+ from vector_quantize_pytorch import ResidualVQ, FSQ
7
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
8
+
9
+ class Bottleneck(nn.Module):
10
+ def __init__(self, is_discrete: bool = False):
11
+ super().__init__()
12
+
13
+ self.is_discrete = is_discrete
14
+
15
+ def encode(self, x, return_info=False, **kwargs):
16
+ raise NotImplementedError
17
+
18
+ def decode(self, x):
19
+ raise NotImplementedError
20
+
21
+ class DiscreteBottleneck(Bottleneck):
22
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
23
+ super().__init__(is_discrete=True)
24
+
25
+ self.num_quantizers = num_quantizers
26
+ self.codebook_size = codebook_size
27
+ self.tokens_id = tokens_id
28
+
29
+ def decode_tokens(self, codes, **kwargs):
30
+ raise NotImplementedError
31
+
32
+ class TanhBottleneck(Bottleneck):
33
+ def __init__(self):
34
+ super().__init__(is_discrete=False)
35
+ self.tanh = nn.Tanh()
36
+
37
+ def encode(self, x, return_info=False):
38
+ info = {}
39
+
40
+ x = torch.tanh(x)
41
+
42
+ if return_info:
43
+ return x, info
44
+ else:
45
+ return x
46
+
47
+ def decode(self, x):
48
+ return x
49
+
50
+ def vae_sample(mean, scale):
51
+ stdev = nn.functional.softplus(scale) + 1e-4
52
+ var = stdev * stdev
53
+ logvar = torch.log(var)
54
+ latents = torch.randn_like(mean) * stdev + mean
55
+
56
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
57
+
58
+ return latents, kl
59
+
60
+ class VAEBottleneck(Bottleneck):
61
+ def __init__(self):
62
+ super().__init__(is_discrete=False)
63
+
64
+ def encode(self, x, return_info=False, **kwargs):
65
+ info = {}
66
+
67
+ mean, scale = x.chunk(2, dim=1)
68
+
69
+ x, kl = vae_sample(mean, scale)
70
+
71
+ info["kl"] = kl
72
+
73
+ if return_info:
74
+ return x, info
75
+ else:
76
+ return x
77
+
78
+ def decode(self, x):
79
+ return x
80
+
81
+ def compute_mean_kernel(x, y):
82
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
83
+ return torch.exp(-kernel_input).mean()
84
+
85
+ def compute_mmd(latents):
86
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
87
+ noise = torch.randn_like(latents_reshaped)
88
+
89
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
90
+ noise_kernel = compute_mean_kernel(noise, noise)
91
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
92
+
93
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
94
+ return mmd.mean()
95
+
96
+ class WassersteinBottleneck(Bottleneck):
97
+ def __init__(self, noise_augment_dim: int = 0):
98
+ super().__init__(is_discrete=False)
99
+
100
+ self.noise_augment_dim = noise_augment_dim
101
+
102
+ def encode(self, x, return_info=False):
103
+ info = {}
104
+
105
+ if self.training and return_info:
106
+ mmd = compute_mmd(x)
107
+ info["mmd"] = mmd
108
+
109
+ if return_info:
110
+ return x, info
111
+
112
+ return x
113
+
114
+ def decode(self, x):
115
+
116
+ if self.noise_augment_dim > 0:
117
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
118
+ x.shape[-1]).type_as(x)
119
+ x = torch.cat([x, noise], dim=1)
120
+
121
+ return x
122
+
123
+ class L2Bottleneck(Bottleneck):
124
+ def __init__(self):
125
+ super().__init__(is_discrete=False)
126
+
127
+ def encode(self, x, return_info=False):
128
+ info = {}
129
+
130
+ x = F.normalize(x, dim=1)
131
+
132
+ if return_info:
133
+ return x, info
134
+ else:
135
+ return x
136
+
137
+ def decode(self, x):
138
+ return F.normalize(x, dim=1)
139
+
140
+ class RVQBottleneck(DiscreteBottleneck):
141
+ def __init__(self, **quantizer_kwargs):
142
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
143
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
144
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
145
+
146
+ def encode(self, x, return_info=False, **kwargs):
147
+ info = {}
148
+
149
+ x = rearrange(x, "b c n -> b n c")
150
+ x, indices, loss = self.quantizer(x)
151
+ x = rearrange(x, "b n c -> b c n")
152
+
153
+ info["quantizer_indices"] = indices
154
+ info["quantizer_loss"] = loss.mean()
155
+
156
+ if return_info:
157
+ return x, info
158
+ else:
159
+ return x
160
+
161
+ def decode(self, x):
162
+ return x
163
+
164
+ def decode_tokens(self, codes, **kwargs):
165
+ latents = self.quantizer.get_outputs_from_indices(codes)
166
+
167
+ return self.decode(latents, **kwargs)
168
+
169
+ class RVQVAEBottleneck(DiscreteBottleneck):
170
+ def __init__(self, **quantizer_kwargs):
171
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
172
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
173
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
174
+
175
+ def encode(self, x, return_info=False):
176
+ info = {}
177
+
178
+ x, kl = vae_sample(*x.chunk(2, dim=1))
179
+
180
+ info["kl"] = kl
181
+
182
+ x = rearrange(x, "b c n -> b n c")
183
+ x, indices, loss = self.quantizer(x)
184
+ x = rearrange(x, "b n c -> b c n")
185
+
186
+ info["quantizer_indices"] = indices
187
+ info["quantizer_loss"] = loss.mean()
188
+
189
+ if return_info:
190
+ return x, info
191
+ else:
192
+ return x
193
+
194
+ def decode(self, x):
195
+ return x
196
+
197
+ def decode_tokens(self, codes, **kwargs):
198
+ latents = self.quantizer.get_outputs_from_indices(codes)
199
+
200
+ return self.decode(latents, **kwargs)
201
+
202
+ class DACRVQBottleneck(DiscreteBottleneck):
203
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
204
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
205
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
206
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
207
+ self.quantize_on_decode = quantize_on_decode
208
+
209
+ def encode(self, x, return_info=False, **kwargs):
210
+ info = {}
211
+
212
+ info["pre_quantizer"] = x
213
+
214
+ if self.quantize_on_decode:
215
+ return (x, info) if return_info else x
216
+
217
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
218
+
219
+ output = {
220
+ "z": z,
221
+ "codes": codes,
222
+ "latents": latents,
223
+ "vq/commitment_loss": commitment_loss,
224
+ "vq/codebook_loss": codebook_loss,
225
+ }
226
+
227
+ output["vq/commitment_loss"] /= self.num_quantizers
228
+ output["vq/codebook_loss"] /= self.num_quantizers
229
+
230
+ info.update(output)
231
+
232
+ if return_info:
233
+ return output["z"], info
234
+
235
+ return output["z"]
236
+
237
+ def decode(self, x):
238
+
239
+ if self.quantize_on_decode:
240
+ x = self.quantizer(x)[0]
241
+
242
+ return x
243
+
244
+ def decode_tokens(self, codes, **kwargs):
245
+ latents, _, _ = self.quantizer.from_codes(codes)
246
+
247
+ return self.decode(latents, **kwargs)
248
+
249
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
250
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
251
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
252
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
253
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
254
+ self.quantize_on_decode = quantize_on_decode
255
+
256
+ def encode(self, x, return_info=False, n_quantizers: int = None):
257
+ info = {}
258
+
259
+ mean, scale = x.chunk(2, dim=1)
260
+
261
+ x, kl = vae_sample(mean, scale)
262
+
263
+ info["pre_quantizer"] = x
264
+ info["kl"] = kl
265
+
266
+ if self.quantize_on_decode:
267
+ return (x, info) if return_info else x
268
+
269
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
270
+
271
+ output = {
272
+ "z": z,
273
+ "codes": codes,
274
+ "latents": latents,
275
+ "vq/commitment_loss": commitment_loss,
276
+ "vq/codebook_loss": codebook_loss,
277
+ }
278
+
279
+ output["vq/commitment_loss"] /= self.num_quantizers
280
+ output["vq/codebook_loss"] /= self.num_quantizers
281
+
282
+ info.update(output)
283
+
284
+ if return_info:
285
+ return output["z"], info
286
+
287
+ return output["z"]
288
+
289
+ def decode(self, x):
290
+
291
+ if self.quantize_on_decode:
292
+ x = self.quantizer(x)[0]
293
+
294
+ return x
295
+
296
+ def decode_tokens(self, codes, **kwargs):
297
+ latents, _, _ = self.quantizer.from_codes(codes)
298
+
299
+ return self.decode(latents, **kwargs)
300
+
301
+ class FSQBottleneck(DiscreteBottleneck):
302
+ def __init__(self, dim, levels):
303
+ super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
304
+ self.quantizer = FSQ(levels=[levels] * dim)
305
+
306
+ def encode(self, x, return_info=False):
307
+ info = {}
308
+
309
+ x = rearrange(x, "b c n -> b n c")
310
+ x, indices = self.quantizer(x)
311
+ x = rearrange(x, "b n c -> b c n")
312
+
313
+ info["quantizer_indices"] = indices
314
+
315
+ if return_info:
316
+ return x, info
317
+ else:
318
+ return x
319
+
320
+ def decode(self, x):
321
+ return x
322
+
323
+ def decode_tokens(self, tokens, **kwargs):
324
+ latents = self.quantizer.indices_to_codes(tokens)
325
+
326
+ return self.decode(latents, **kwargs)
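The VAE bottleneck above treats the encoder output as mean and scale halves along the channel dimension: vae_sample converts the scale to a standard deviation with a softplus, draws latents by reparameterisation, and returns a KL-style penalty for the training loop to weight. A minimal sketch with invented shapes, assuming this module is importable as stable_audio_tools.models.bottleneck:

import torch
from stable_audio_tools.models.bottleneck import VAEBottleneck

bottleneck = VAEBottleneck()
encoder_out = torch.randn(4, 2 * 64, 256)                 # (batch, 2 * latent_channels, seq)
latents, info = bottleneck.encode(encoder_out, return_info=True)
print(latents.shape)                                      # torch.Size([4, 64, 256])
print(info["kl"])                                         # scalar penalty, weighted elsewhere during training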
stable_audio_tools/models/conditioners.py ADDED
@@ -0,0 +1,558 @@
1
+ #Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
+
3
+ import torch
4
+ import logging, warnings
5
+ import string
6
+ import typing as tp
7
+ import gc
8
+
9
+ from .adp import NumberEmbedder
10
+ from ..inference.utils import set_audio_channels
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..training.utils import copy_state_dict
14
+ from .utils import load_ckpt_state_dict
15
+
16
+ from torch import nn
17
+
18
+ class Conditioner(nn.Module):
19
+ def __init__(
20
+ self,
21
+ dim: int,
22
+ output_dim: int,
23
+ project_out: bool = False,
24
+ ):
25
+
26
+ super().__init__()
27
+
28
+ self.dim = dim
29
+ self.output_dim = output_dim
30
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
31
+
32
+ def forward(self, x: tp.Any) -> tp.Any:
33
+ raise NotImplementedError()
34
+
35
+ class IntConditioner(Conditioner):
36
+ def __init__(self,
37
+ output_dim: int,
38
+ min_val: int=0,
39
+ max_val: int=512
40
+ ):
41
+ super().__init__(output_dim, output_dim)
42
+
43
+ self.min_val = min_val
44
+ self.max_val = max_val
45
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
46
+
47
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
48
+
49
+ #self.int_embedder.to(device)
50
+
51
+ ints = torch.tensor(ints).to(device)
52
+ ints = ints.clamp(self.min_val, self.max_val)
53
+
54
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
55
+
56
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
57
+
58
+ class NumberConditioner(Conditioner):
59
+ '''
60
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
61
+ '''
62
+ def __init__(self,
63
+ output_dim: int,
64
+ min_val: float=0,
65
+ max_val: float=1
66
+ ):
67
+ super().__init__(output_dim, output_dim)
68
+
69
+ self.min_val = min_val
70
+ self.max_val = max_val
71
+
72
+ self.embedder = NumberEmbedder(features=output_dim)
73
+
74
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
75
+
76
+ # Cast the inputs to floats
77
+ floats = [float(x) for x in floats]
78
+
79
+ floats = torch.tensor(floats).to(device)
80
+
81
+ floats = floats.clamp(self.min_val, self.max_val)
82
+
83
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
84
+
85
+ # Cast floats to same type as embedder
86
+ embedder_dtype = next(self.embedder.parameters()).dtype
87
+ normalized_floats = normalized_floats.to(embedder_dtype)
88
+
89
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
90
+
91
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
92
+
93
+ class CLAPTextConditioner(Conditioner):
94
+ def __init__(self,
95
+ output_dim: int,
96
+ clap_ckpt_path,
97
+ use_text_features = False,
98
+ feature_layer_ix: int = -1,
99
+ audio_model_type="HTSAT-base",
100
+ enable_fusion=True,
101
+ project_out: bool = False,
102
+ finetune: bool = False):
103
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
104
+
105
+ self.use_text_features = use_text_features
106
+ self.feature_layer_ix = feature_layer_ix
107
+ self.finetune = finetune
108
+
109
+ # Suppress logging from transformers
110
+ previous_level = logging.root.manager.disable
111
+ logging.disable(logging.ERROR)
112
+ with warnings.catch_warnings():
113
+ warnings.simplefilter("ignore")
114
+ try:
115
+ import laion_clap
116
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
117
+
118
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
119
+
120
+ if self.finetune:
121
+ self.model = model
122
+ else:
123
+ self.__dict__["model"] = model
124
+
125
+ state_dict = clap_load_state_dict(clap_ckpt_path)
126
+ self.model.model.load_state_dict(state_dict, strict=False)
127
+
128
+ if self.finetune:
129
+ self.model.model.text_branch.requires_grad_(True)
130
+ self.model.model.text_branch.train()
131
+ else:
132
+ self.model.model.text_branch.requires_grad_(False)
133
+ self.model.model.text_branch.eval()
134
+
135
+ finally:
136
+ logging.disable(previous_level)
137
+
138
+ del self.model.model.audio_branch
139
+
140
+ gc.collect()
141
+ torch.cuda.empty_cache()
142
+
143
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
144
+ prompt_tokens = self.model.tokenizer(prompts)
145
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
146
+ prompt_features = self.model.model.text_branch(
147
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
148
+ attention_mask=attention_mask,
149
+ output_hidden_states=True
150
+ )["hidden_states"][layer_ix]
151
+
152
+ return prompt_features, attention_mask
153
+
154
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
155
+ self.model.to(device)
156
+
157
+ if self.use_text_features:
158
+ if len(texts) == 1:
159
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
160
+ text_features = text_features[:1, ...]
161
+ text_attention_mask = text_attention_mask[:1, ...]
162
+ else:
163
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
164
+ return [self.proj_out(text_features), text_attention_mask]
165
+
166
+ # Fix for CLAP bug when only one text is passed
167
+ if len(texts) == 1:
168
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
169
+ else:
170
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
171
+
172
+ text_embedding = text_embedding.unsqueeze(1).to(device)
173
+
174
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
175
+
176
+ class CLAPAudioConditioner(Conditioner):
177
+ def __init__(self,
178
+ output_dim: int,
179
+ clap_ckpt_path,
180
+ audio_model_type="HTSAT-base",
181
+ enable_fusion=True,
182
+ project_out: bool = False,
+ finetune: bool = False):
183
+ super().__init__(512, output_dim, project_out=project_out)
+ self.finetune = finetune
184
+
185
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
186
+
187
+ # Suppress logging from transformers
188
+ previous_level = logging.root.manager.disable
189
+ logging.disable(logging.ERROR)
190
+ with warnings.catch_warnings():
191
+ warnings.simplefilter("ignore")
192
+ try:
193
+ import laion_clap
194
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
195
+
196
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
197
+
198
+ if self.finetune:
199
+ self.model = model
200
+ else:
201
+ self.__dict__["model"] = model
202
+
203
+ state_dict = clap_load_state_dict(clap_ckpt_path)
204
+ self.model.model.load_state_dict(state_dict, strict=False)
205
+
206
+ if self.finetune:
207
+ self.model.model.audio_branch.requires_grad_(True)
208
+ self.model.model.audio_branch.train()
209
+ else:
210
+ self.model.model.audio_branch.requires_grad_(False)
211
+ self.model.model.audio_branch.eval()
212
+
213
+ finally:
214
+ logging.disable(previous_level)
215
+
216
+ del self.model.model.text_branch
217
+
218
+ gc.collect()
219
+ torch.cuda.empty_cache()
220
+
221
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
222
+
223
+ self.model.to(device)
224
+
225
+ if isinstance(audios, list) or isinstance(audios, tuple):
226
+ audios = torch.cat(audios, dim=0)
227
+
228
+ # Convert to mono
229
+ mono_audios = audios.mean(dim=1)
230
+
231
+ with torch.cuda.amp.autocast(enabled=False):
232
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
233
+
234
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
235
+
236
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
237
+
238
+ class T5Conditioner(Conditioner):
239
+
240
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
241
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
242
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
243
+
244
+ T5_MODEL_DIMS = {
245
+ "t5-small": 512,
246
+ "t5-base": 768,
247
+ "t5-large": 1024,
248
+ "t5-3b": 1024,
249
+ "t5-11b": 1024,
250
+ "t5-xl": 2048,
251
+ "t5-xxl": 4096,
252
+ "google/flan-t5-small": 512,
253
+ "google/flan-t5-base": 768,
254
+ "google/flan-t5-large": 1024,
255
+ "google/flan-t5-3b": 1024,
256
+ "google/flan-t5-11b": 1024,
257
+ "google/flan-t5-xl": 2048,
258
+ "google/flan-t5-xxl": 4096,
259
+ }
260
+
261
+ def __init__(
262
+ self,
263
+ output_dim: int,
264
+ t5_model_name: str = "t5-base",
265
+ max_length: int = 128,
266
+ enable_grad: bool = False,
267
+ project_out: bool = False,
268
+ ):
269
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
270
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
271
+
272
+ from transformers import T5EncoderModel, AutoTokenizer
273
+
274
+ self.max_length = max_length
275
+ self.enable_grad = enable_grad
276
+
277
+ # Suppress logging from transformers
278
+ previous_level = logging.root.manager.disable
279
+ logging.disable(logging.ERROR)
280
+ with warnings.catch_warnings():
281
+ warnings.simplefilter("ignore")
282
+ try:
283
+ # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
284
+ # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
285
+ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
286
+ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
287
+ finally:
288
+ logging.disable(previous_level)
289
+
290
+ if self.enable_grad:
291
+ self.model = model
292
+ else:
293
+ self.__dict__["model"] = model
294
+
295
+
296
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
297
+
298
+ self.model.to(device)
299
+ self.proj_out.to(device)
300
+
301
+ encoded = self.tokenizer(
302
+ texts,
303
+ truncation=True,
304
+ max_length=self.max_length,
305
+ padding="max_length",
306
+ return_tensors="pt",
307
+ )
308
+
309
+ input_ids = encoded["input_ids"].to(device)
310
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
311
+
312
+ self.model.eval()
313
+
314
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
315
+ embeddings = self.model(
316
+ input_ids=input_ids, attention_mask=attention_mask
317
+ )["last_hidden_state"]
318
+
319
+ embeddings = self.proj_out(embeddings.float())
320
+
321
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
322
+
323
+ return embeddings, attention_mask
324
+
325
+ class PhonemeConditioner(Conditioner):
326
+ """
327
+ A conditioner that turns text into phonemes and embeds them using a lookup table
328
+ Only works for English text
329
+
330
+ Args:
331
+ output_dim: the dimension of the output embeddings
332
+ max_length: the maximum number of phonemes to embed
333
+ project_out: whether to add another linear projection to the output embeddings
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ output_dim: int,
339
+ max_length: int = 1024,
340
+ project_out: bool = False,
341
+ ):
342
+ super().__init__(output_dim, output_dim, project_out=project_out)
343
+
344
+ from g2p_en import G2p
345
+
346
+ self.max_length = max_length
347
+
348
+ self.g2p = G2p()
349
+
350
+ # Reserving 0 for padding, 1 for ignored
351
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
352
+
353
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
354
+
355
+ self.phoneme_embedder.to(device)
356
+ self.proj_out.to(device)
357
+
358
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
359
+
360
+ phoneme_ignore = [" ", *string.punctuation]
361
+
362
+ # Map ignored phonemes (spaces and punctuation) to a placeholder token
363
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
364
+
365
+ # Convert to ids
366
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
367
+
368
+ #Pad to match longest and make a mask tensor for the padding
369
+ longest = max([len(ids) for ids in phoneme_ids])
370
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
371
+
372
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
373
+
374
+ # Convert to embeddings
375
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
376
+
377
+ phoneme_embeds = self.proj_out(phoneme_embeds)
378
+
379
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
380
+
381
+ class TokenizerLUTConditioner(Conditioner):
382
+ """
383
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
384
+
385
+ Args:
386
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
387
+ output_dim: the dimension of the output embeddings
388
+ max_length: the maximum length of the text to embed
389
+ project_out: whether to add another linear projection to the output embeddings
390
+ """
391
+
392
+ def __init__(
393
+ self,
394
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
395
+ output_dim: int,
396
+ max_length: int = 1024,
397
+ project_out: bool = False,
398
+ ):
399
+ super().__init__(output_dim, output_dim, project_out=project_out)
400
+
401
+ from transformers import AutoTokenizer
402
+
403
+ # Suppress logging from transformers
404
+ previous_level = logging.root.manager.disable
405
+ logging.disable(logging.ERROR)
406
+ with warnings.catch_warnings():
407
+ warnings.simplefilter("ignore")
408
+ try:
409
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
410
+ finally:
411
+ logging.disable(previous_level)
412
+
413
+ self.max_length = max_length
414
+
415
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
416
+
417
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
418
+ self.proj_out.to(device)
419
+
420
+ encoded = self.tokenizer(
421
+ texts,
422
+ truncation=True,
423
+ max_length=self.max_length,
424
+ padding="max_length",
425
+ return_tensors="pt",
426
+ )
427
+
428
+ input_ids = encoded["input_ids"].to(device)
429
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
430
+
431
+ embeddings = self.token_embedder(input_ids)
432
+
433
+ embeddings = self.proj_out(embeddings)
434
+
435
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
436
+
437
+ return embeddings, attention_mask
438
+
439
+ class PretransformConditioner(Conditioner):
440
+ """
441
+ A conditioner that uses a pretransform's encoder for conditioning
442
+
443
+ Args:
444
+ pretransform: an instantiated pretransform to use for conditioning
445
+ output_dim: the dimension of the output embeddings
446
+ """
447
+ def __init__(self, pretransform: Pretransform, output_dim: int):
448
+ super().__init__(pretransform.encoded_channels, output_dim)
449
+
450
+ self.pretransform = pretransform
451
+
452
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
453
+
454
+ self.pretransform.to(device)
455
+ self.proj_out.to(device)
456
+
457
+ if isinstance(audio, list) or isinstance(audio, tuple):
458
+ audio = torch.cat(audio, dim=0)
459
+
460
+ # Convert audio to pretransform input channels
461
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
462
+
463
+ latents = self.pretransform.encode(audio)
464
+
465
+ latents = self.proj_out(latents)
466
+
467
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
468
+
469
+ class MultiConditioner(nn.Module):
470
+ """
471
+ A module that applies multiple conditioners to an input dictionary based on the keys
472
+
473
+ Args:
474
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
475
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
476
+ """
477
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
478
+ super().__init__()
479
+
480
+ self.conditioners = nn.ModuleDict(conditioners)
481
+ self.default_keys = default_keys
482
+
483
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
484
+ output = {}
485
+
486
+ for key, conditioner in self.conditioners.items():
487
+ condition_key = key
488
+
489
+ conditioner_inputs = []
490
+
491
+ for x in batch_metadata:
492
+
493
+ if condition_key not in x:
494
+ if condition_key in self.default_keys:
495
+ condition_key = self.default_keys[condition_key]
496
+ else:
497
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
498
+
499
+ #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list
500
+ if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
501
+ conditioner_inputs.append(x[condition_key][0])
502
+ else:
503
+ conditioner_inputs.append(x[condition_key])
504
+
505
+ output[key] = conditioner(conditioner_inputs, device)
506
+
507
+ return output
508
+
509
+ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
510
+ """
511
+ Create a MultiConditioner from a conditioning config dictionary
512
+
513
+ Args:
514
+ config: the conditioning config dictionary
515
+ device: the device to put the conditioners on
516
+ """
517
+ conditioners = {}
518
+ cond_dim = config["cond_dim"]
519
+
520
+ default_keys = config.get("default_keys", {})
521
+
522
+ for conditioner_info in config["configs"]:
523
+ id = conditioner_info["id"]
524
+
525
+ conditioner_type = conditioner_info["type"]
526
+
527
+ conditioner_config = {"output_dim": cond_dim}
528
+
529
+ conditioner_config.update(conditioner_info["config"])
530
+
531
+ if conditioner_type == "t5":
532
+ conditioners[id] = T5Conditioner(**conditioner_config)
533
+ elif conditioner_type == "clap_text":
534
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
535
+ elif conditioner_type == "clap_audio":
536
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
537
+ elif conditioner_type == "int":
538
+ conditioners[id] = IntConditioner(**conditioner_config)
539
+ elif conditioner_type == "number":
540
+ conditioners[id] = NumberConditioner(**conditioner_config)
541
+ elif conditioner_type == "phoneme":
542
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
543
+ elif conditioner_type == "lut":
544
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
545
+ elif conditioner_type == "pretransform":
546
+ sample_rate = conditioner_config.pop("sample_rate", None)
547
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
548
+
549
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
550
+
551
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
552
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
553
+
554
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
555
+ else:
556
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
557
+
558
+ return MultiConditioner(conditioners, default_keys=default_keys)
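create_multi_conditioner_from_conditioning_config expects a dict with a shared cond_dim and a list of conditioner entries, each carrying an id (the metadata key it consumes), a type, and a type-specific config. The sketch below is illustrative only: the ids, dimensions and prompt are invented, the t5 entry downloads t5-base from the Hugging Face hub on first use, and a CUDA device is assumed because the T5 weights are cast to float16.

from stable_audio_tools.models.conditioners import create_multi_conditioner_from_conditioning_config

conditioning_config = {
    "cond_dim": 768,
    "configs": [
        {"id": "prompt",  "type": "t5",     "config": {"t5_model_name": "t5-base", "max_length": 128}},
        {"id": "seconds", "type": "number", "config": {"min_val": 0, "max_val": 512}},
    ],
}

conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config).to("cuda")

# One metadata dict per batch item; keys must match the conditioner ids (or a default_keys alias).
cond_tensors = conditioner([{"prompt": "warm analog pad", "seconds": 47.5}], device="cuda")
# cond_tensors["prompt"]  -> (token embeddings, attention mask) from the T5 encoder
# cond_tensors["seconds"] -> [embedding, mask] from the NumberEmbedder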
stable_audio_tools/models/diffusion.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from functools import partial, reduce
5
+ import numpy as np
6
+ import typing as tp
7
+
8
+ from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ from .dit import DiffusionTransformer
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..inference.generation import generate_diffusion_cond
14
+
15
+ from .adp import UNetCFG1d, UNet1d
16
+
17
+ from time import time
18
+
19
+ class Profiler:
20
+
21
+ def __init__(self):
22
+ self.ticks = [[time(), None]]
23
+
24
+ def tick(self, msg):
25
+ self.ticks.append([time(), msg])
26
+
27
+ def __repr__(self):
28
+ rep = 80 * "=" + "\n"
29
+ for i in range(1, len(self.ticks)):
30
+ msg = self.ticks[i][1]
31
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
32
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
33
+ rep += 80 * "=" + "\n\n\n"
34
+ return rep
35
+
36
+ class DiffusionModel(nn.Module):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+ def forward(self, x, t, **kwargs):
41
+ raise NotImplementedError()
42
+
43
+ class DiffusionModelWrapper(nn.Module):
44
+ def __init__(
45
+ self,
46
+ model: DiffusionModel,
47
+ io_channels,
48
+ sample_size,
49
+ sample_rate,
50
+ min_input_length,
51
+ pretransform: tp.Optional[Pretransform] = None,
52
+ ):
53
+ super().__init__()
54
+ self.io_channels = io_channels
55
+ self.sample_size = sample_size
56
+ self.sample_rate = sample_rate
57
+ self.min_input_length = min_input_length
58
+
59
+ self.model = model
60
+
61
+ if pretransform is not None:
62
+ self.pretransform = pretransform
63
+ else:
64
+ self.pretransform = None
65
+
66
+ def forward(self, x, t, **kwargs):
67
+ return self.model(x, t, **kwargs)
68
+
69
+ class ConditionedDiffusionModel(nn.Module):
70
+ def __init__(self,
71
+ *args,
72
+ supports_cross_attention: bool = False,
73
+ supports_input_concat: bool = False,
74
+ supports_global_cond: bool = False,
75
+ supports_prepend_cond: bool = False,
76
+ **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+ self.supports_cross_attention = supports_cross_attention
79
+ self.supports_input_concat = supports_input_concat
80
+ self.supports_global_cond = supports_global_cond
81
+ self.supports_prepend_cond = supports_prepend_cond
82
+
83
+ def forward(self,
84
+ x: torch.Tensor,
85
+ t: torch.Tensor,
86
+ cross_attn_cond: torch.Tensor = None,
87
+ cross_attn_mask: torch.Tensor = None,
88
+ input_concat_cond: torch.Tensor = None,
89
+ global_embed: torch.Tensor = None,
90
+ prepend_cond: torch.Tensor = None,
91
+ prepend_cond_mask: torch.Tensor = None,
92
+ cfg_scale: float = 1.0,
93
+ cfg_dropout_prob: float = 0.0,
94
+ batch_cfg: bool = False,
95
+ rescale_cfg: bool = False,
96
+ **kwargs):
97
+ raise NotImplementedError()
98
+
99
+ class ConditionedDiffusionModelWrapper(nn.Module):
100
+ """
101
+ A diffusion model that takes in conditioning
102
+ """
103
+ def __init__(
104
+ self,
105
+ model: ConditionedDiffusionModel,
106
+ conditioner: MultiConditioner,
107
+ io_channels,
108
+ sample_rate,
109
+ min_input_length: int,
110
+ pretransform: tp.Optional[Pretransform] = None,
111
+ cross_attn_cond_ids: tp.List[str] = [],
112
+ global_cond_ids: tp.List[str] = [],
113
+ input_concat_ids: tp.List[str] = [],
114
+ prepend_cond_ids: tp.List[str] = [],
115
+ ):
116
+ super().__init__()
117
+
118
+ self.model = model
119
+ self.conditioner = conditioner
120
+ self.io_channels = io_channels
121
+ self.sample_rate = sample_rate
122
+ self.pretransform = pretransform
123
+ self.cross_attn_cond_ids = cross_attn_cond_ids
124
+ self.global_cond_ids = global_cond_ids
125
+ self.input_concat_ids = input_concat_ids
126
+ self.prepend_cond_ids = prepend_cond_ids
127
+ self.min_input_length = min_input_length
128
+
129
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
130
+ cross_attention_input = None
131
+ cross_attention_masks = None
132
+ global_cond = None
133
+ input_concat_cond = None
134
+ prepend_cond = None
135
+ prepend_cond_mask = None
136
+
137
+ if len(self.cross_attn_cond_ids) > 0:
138
+ # Concatenate all cross-attention inputs over the sequence dimension
139
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
140
+ cross_attention_input = []
141
+ cross_attention_masks = []
142
+
143
+ for key in self.cross_attn_cond_ids:
144
+ cross_attn_in, cross_attn_mask = cond[key]
145
+
146
+ # Add sequence dimension if it's not there
147
+ if len(cross_attn_in.shape) == 2:
148
+ cross_attn_in = cross_attn_in.unsqueeze(1)
149
+ cross_attn_mask = cross_attn_mask.unsqueeze(1)
150
+
151
+ cross_attention_input.append(cross_attn_in)
152
+ cross_attention_masks.append(cross_attn_mask)
153
+
154
+ cross_attention_input = torch.cat(cross_attention_input, dim=1)
155
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
156
+
157
+ if len(self.global_cond_ids) > 0:
158
+ # Concatenate all global conditioning inputs over the channel dimension
159
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
160
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
161
+ if len(global_cond.shape) == 3:
162
+ global_cond = global_cond.squeeze(1)
163
+
164
+ if len(self.input_concat_ids) > 0:
165
+ # Concatenate all input concat conditioning inputs over the channel dimension
166
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
167
+ input_concat_cond = torch.cat([cond[key][0] for key in self.input_concat_ids], dim=1)
168
+
169
+ if len(self.prepend_cond_ids) > 0:
170
+ # Concatenate all prepend conditioning inputs over the sequence dimension
171
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
172
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
173
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
174
+
175
+ if negative:
176
+ return {
177
+ "negative_cross_attn_cond": cross_attention_input,
178
+ "negative_cross_attn_mask": cross_attention_masks,
179
+ "negative_global_cond": global_cond,
180
+ "negative_input_concat_cond": input_concat_cond
181
+ }
182
+ else:
183
+ return {
184
+ "cross_attn_cond": cross_attention_input,
185
+ "cross_attn_mask": cross_attention_masks,
186
+ "global_cond": global_cond,
187
+ "input_concat_cond": input_concat_cond,
188
+ "prepend_cond": prepend_cond,
189
+ "prepend_cond_mask": prepend_cond_mask
190
+ }
191
+
192
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
193
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
194
+
195
+ def generate(self, *args, **kwargs):
196
+ return generate_diffusion_cond(self, *args, **kwargs)
197
+
198
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
199
+ def __init__(
200
+ self,
201
+ *args,
202
+ **kwargs
203
+ ):
204
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
205
+
206
+ self.model = UNetCFG1d(*args, **kwargs)
207
+
208
+ with torch.no_grad():
209
+ for param in self.model.parameters():
210
+ param *= 0.5
211
+
212
+ def forward(self,
213
+ x,
214
+ t,
215
+ cross_attn_cond=None,
216
+ cross_attn_mask=None,
217
+ input_concat_cond=None,
218
+ global_cond=None,
219
+ cfg_scale=1.0,
220
+ cfg_dropout_prob: float = 0.0,
221
+ batch_cfg: bool = False,
222
+ rescale_cfg: bool = False,
223
+ negative_cross_attn_cond=None,
224
+ negative_cross_attn_mask=None,
225
+ negative_global_cond=None,
226
+ negative_input_concat_cond=None,
227
+ prepend_cond=None,
228
+ prepend_cond_mask=None,
229
+ **kwargs):
230
+ p = Profiler()
231
+
232
+ p.tick("start")
233
+
234
+ channels_list = None
235
+ if input_concat_cond is not None:
236
+ channels_list = [input_concat_cond]
237
+
238
+ outputs = self.model(
239
+ x,
240
+ t,
241
+ embedding=cross_attn_cond,
242
+ embedding_mask=cross_attn_mask,
243
+ features=global_cond,
244
+ channels_list=channels_list,
245
+ embedding_scale=cfg_scale,
246
+ embedding_mask_proba=cfg_dropout_prob,
247
+ batch_cfg=batch_cfg,
248
+ rescale_cfg=rescale_cfg,
249
+ negative_embedding=negative_cross_attn_cond,
250
+ negative_embedding_mask=negative_cross_attn_mask,
251
+ **kwargs)
252
+
253
+ p.tick("UNetCFG1D forward")
254
+
255
+ #print(f"Profiler: {p}")
256
+ return outputs
257
+
258
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
259
+ def __init__(
260
+ self,
261
+ *args,
262
+ **kwargs
263
+ ):
264
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
265
+
266
+ self.model = UNet1d(*args, **kwargs)
267
+
268
+ with torch.no_grad():
269
+ for param in self.model.parameters():
270
+ param *= 0.5
271
+
272
+ def forward(self,
273
+ x,
274
+ t,
275
+ input_concat_cond=None,
276
+ global_cond=None,
277
+ cross_attn_cond=None,
278
+ cross_attn_mask=None,
279
+ prepend_cond=None,
280
+ prepend_cond_mask=None,
281
+ cfg_scale=1.0,
282
+ cfg_dropout_prob: float = 0.0,
283
+ batch_cfg: bool = False,
284
+ rescale_cfg: bool = False,
285
+ negative_cross_attn_cond=None,
286
+ negative_cross_attn_mask=None,
287
+ negative_global_cond=None,
288
+ negative_input_concat_cond=None,
289
+ **kwargs):
290
+
291
+ channels_list = None
292
+ if input_concat_cond is not None:
293
+
294
+ # Interpolate input_concat_cond to the same length as x
295
+ if input_concat_cond.shape[2] != x.shape[2]:
296
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
297
+
298
+ channels_list = [input_concat_cond]
299
+
300
+ outputs = self.model(
301
+ x,
302
+ t,
303
+ features=global_cond,
304
+ channels_list=channels_list,
305
+ **kwargs)
306
+
307
+ return outputs
308
+
309
+ class UNet1DUncondWrapper(DiffusionModel):
310
+ def __init__(
311
+ self,
312
+ in_channels,
313
+ *args,
314
+ **kwargs
315
+ ):
316
+ super().__init__()
317
+
318
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
319
+
320
+ self.io_channels = in_channels
321
+
322
+ with torch.no_grad():
323
+ for param in self.model.parameters():
324
+ param *= 0.5
325
+
326
+ def forward(self, x, t, **kwargs):
327
+ return self.model(x, t, **kwargs)
328
+
329
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
330
+ def __init__(
331
+ self,
332
+ *args,
333
+ **kwargs
334
+ ):
335
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
336
+
337
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
338
+
339
+ with torch.no_grad():
340
+ for param in self.model.parameters():
341
+ param *= 0.5
342
+
343
+ def forward(self,
344
+ x,
345
+ t,
346
+ input_concat_cond=None,
347
+ cross_attn_cond=None,
348
+ cross_attn_mask=None,
349
+ global_cond=None,
350
+ cfg_scale=1.0,
351
+ cfg_dropout_prob: float = 0.0,
352
+ batch_cfg: bool = False,
353
+ rescale_cfg: bool = False,
354
+ negative_cross_attn_cond=None,
355
+ negative_cross_attn_mask=None,
356
+ negative_global_cond=None,
357
+ negative_input_concat_cond=None,
358
+ prepend_cond=None,
359
+ **kwargs):
360
+
361
+ return self.model(x, t, cond = input_concat_cond)
362
+
363
+ class DiffusionAttnUnet1D(nn.Module):
364
+ def __init__(
365
+ self,
366
+ io_channels = 2,
367
+ depth=14,
368
+ n_attn_layers = 6,
369
+ channels = [128, 128, 256, 256] + [512] * 10,
370
+ cond_dim = 0,
371
+ cond_noise_aug = False,
372
+ kernel_size = 5,
373
+ learned_resample = False,
374
+ strides = [2] * 13,
375
+ conv_bias = True,
376
+ use_snake = False
377
+ ):
378
+ super().__init__()
379
+
380
+ self.cond_noise_aug = cond_noise_aug
381
+
382
+ self.io_channels = io_channels
383
+
384
+ if self.cond_noise_aug:
385
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
386
+
387
+ self.timestep_embed = FourierFeatures(1, 16)
388
+
389
+ attn_layer = depth - n_attn_layers
390
+
391
+ strides = [1] + strides
392
+
393
+ block = nn.Identity()
394
+
395
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
396
+
397
+ for i in range(depth, 0, -1):
398
+ c = channels[i - 1]
399
+ stride = strides[i-1]
400
+ if stride > 2 and not learned_resample:
401
+ raise ValueError("Must have stride 2 without learned resampling")
402
+
403
+ if i > 1:
404
+ c_prev = channels[i - 2]
405
+ add_attn = i >= attn_layer and n_attn_layers > 0
406
+ block = SkipBlock(
407
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
408
+ conv_block(c_prev, c, c),
409
+ SelfAttention1d(
410
+ c, c // 32) if add_attn else nn.Identity(),
411
+ conv_block(c, c, c),
412
+ SelfAttention1d(
413
+ c, c // 32) if add_attn else nn.Identity(),
414
+ conv_block(c, c, c),
415
+ SelfAttention1d(
416
+ c, c // 32) if add_attn else nn.Identity(),
417
+ block,
418
+ conv_block(c * 2 if i != depth else c, c, c),
419
+ SelfAttention1d(
420
+ c, c // 32) if add_attn else nn.Identity(),
421
+ conv_block(c, c, c),
422
+ SelfAttention1d(
423
+ c, c // 32) if add_attn else nn.Identity(),
424
+ conv_block(c, c, c_prev),
425
+ SelfAttention1d(c_prev, c_prev //
426
+ 32) if add_attn else nn.Identity(),
427
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
428
+ )
429
+ else:
430
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
431
+ block = nn.Sequential(
432
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
433
+ conv_block(c, c, c),
434
+ conv_block(c, c, c),
435
+ block,
436
+ conv_block(c * 2, c, c),
437
+ conv_block(c, c, c),
438
+ conv_block(c, c, io_channels, is_last=True),
439
+ )
440
+ self.net = block
441
+
442
+ with torch.no_grad():
443
+ for param in self.net.parameters():
444
+ param *= 0.5
445
+
446
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
447
+
448
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
449
+
450
+ inputs = [x, timestep_embed]
451
+
452
+ if cond is not None:
453
+ if cond.shape[2] != x.shape[2]:
454
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
455
+
456
+ if self.cond_noise_aug:
457
+ # Get a random number between 0 and 1, uniformly sampled
458
+ if cond_aug_scale is None:
459
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
460
+ else:
461
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
462
+
463
+ # Add noise to the conditioning signal
464
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
465
+
466
+ # Get an embedding for the noise conditioning level, reusing timestep_embed
467
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
468
+
469
+ inputs.append(aug_level_embed)
470
+
471
+ inputs.append(cond)
472
+
473
+ outputs = self.net(torch.cat(inputs, dim=1))
474
+
475
+ return outputs
476
+
477
+ class DiTWrapper(ConditionedDiffusionModel):
478
+ def __init__(
479
+ self,
480
+ *args,
481
+ **kwargs
482
+ ):
483
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
484
+
485
+ self.model = DiffusionTransformer(*args, **kwargs)
486
+
487
+ with torch.no_grad():
488
+ for param in self.model.parameters():
489
+ param *= 0.5
490
+
491
+ def forward(self,
492
+ x,
493
+ t,
494
+ cross_attn_cond=None,
495
+ cross_attn_mask=None,
496
+ negative_cross_attn_cond=None,
497
+ negative_cross_attn_mask=None,
498
+ input_concat_cond=None,
499
+ negative_input_concat_cond=None,
500
+ global_cond=None,
501
+ negative_global_cond=None,
502
+ prepend_cond=None,
503
+ prepend_cond_mask=None,
504
+ cfg_scale=1.0,
505
+ cfg_dropout_prob: float = 0.0,
506
+ batch_cfg: bool = True,
507
+ rescale_cfg: bool = False,
508
+ scale_phi: float = 0.0,
509
+ **kwargs):
510
+
511
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
512
+ assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
513
+
514
+ return self.model(
515
+ x,
516
+ t,
517
+ cross_attn_cond=cross_attn_cond,
518
+ cross_attn_cond_mask=cross_attn_mask,
519
+ negative_cross_attn_cond=negative_cross_attn_cond,
520
+ negative_cross_attn_mask=negative_cross_attn_mask,
521
+ input_concat_cond=input_concat_cond,
522
+ prepend_cond=prepend_cond,
523
+ prepend_cond_mask=prepend_cond_mask,
524
+ cfg_scale=cfg_scale,
525
+ cfg_dropout_prob=cfg_dropout_prob,
526
+ scale_phi=scale_phi,
527
+ global_embed=global_cond,
528
+ **kwargs)
529
+
530
+ class DiTUncondWrapper(DiffusionModel):
531
+ def __init__(
532
+ self,
533
+ in_channels,
534
+ *args,
535
+ **kwargs
536
+ ):
537
+ super().__init__()
538
+
539
+ self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
540
+
541
+ self.io_channels = in_channels
542
+
543
+ with torch.no_grad():
544
+ for param in self.model.parameters():
545
+ param *= 0.5
546
+
547
+ def forward(self, x, t, **kwargs):
548
+ return self.model(x, t, **kwargs)
549
+
550
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
551
+ diffusion_uncond_config = config["model"]
552
+
553
+ model_type = diffusion_uncond_config.get('type', None)
554
+
555
+ diffusion_config = diffusion_uncond_config.get('config', {})
556
+
557
+ assert model_type is not None, "Must specify model type in config"
558
+
559
+ pretransform = diffusion_uncond_config.get("pretransform", None)
560
+
561
+ sample_size = config.get("sample_size", None)
562
+ assert sample_size is not None, "Must specify sample size in config"
563
+
564
+ sample_rate = config.get("sample_rate", None)
565
+ assert sample_rate is not None, "Must specify sample rate in config"
566
+
567
+ if pretransform is not None:
568
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
569
+ min_input_length = pretransform.downsampling_ratio
570
+ else:
571
+ min_input_length = 1
572
+
573
+ if model_type == 'DAU1d':
574
+
575
+ model = DiffusionAttnUnet1D(
576
+ **diffusion_config
577
+ )
578
+
579
+ elif model_type == "adp_uncond_1d":
580
+
581
+ model = UNet1DUncondWrapper(
582
+ **diffusion_config
583
+ )
584
+
585
+ elif model_type == "dit":
586
+ model = DiTUncondWrapper(
587
+ **diffusion_config
588
+ )
589
+
590
+ else:
591
+ raise NotImplementedError(f'Unknown model type: {model_type}')
592
+
593
+ return DiffusionModelWrapper(model,
594
+ io_channels=model.io_channels,
595
+ sample_size=sample_size,
596
+ sample_rate=sample_rate,
597
+ pretransform=pretransform,
598
+ min_input_length=min_input_length)
599
+
600
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
601
+
602
+ model_config = config["model"]
603
+
604
+ model_type = config["model_type"]
605
+
606
+ diffusion_config = model_config.get('diffusion', None)
607
+ assert diffusion_config is not None, "Must specify diffusion config"
608
+
609
+ diffusion_model_type = diffusion_config.get('type', None)
610
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
611
+
612
+ diffusion_model_config = diffusion_config.get('config', None)
613
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
614
+
615
+ if diffusion_model_type == 'adp_cfg_1d':
616
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
617
+ elif diffusion_model_type == 'adp_1d':
618
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
619
+ elif diffusion_model_type == 'dit':
620
+ diffusion_model = DiTWrapper(**diffusion_model_config)
621
+
622
+ io_channels = model_config.get('io_channels', None)
623
+ assert io_channels is not None, "Must specify io_channels in model config"
624
+
625
+ sample_rate = config.get('sample_rate', None)
626
+ assert sample_rate is not None, "Must specify sample_rate in config"
627
+
628
+ conditioning_config = model_config.get('conditioning', None)
629
+
630
+ conditioner = None
631
+ if conditioning_config is not None:
632
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
633
+
634
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
635
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
636
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
637
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
638
+
639
+ pretransform = model_config.get("pretransform", None)
640
+
641
+ if pretransform is not None:
642
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
643
+ min_input_length = pretransform.downsampling_ratio
644
+ else:
645
+ min_input_length = 1
646
+
647
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
648
+ min_input_length *= np.prod(diffusion_model_config["factors"])
649
+ elif diffusion_model_type == "dit":
650
+ min_input_length *= diffusion_model.model.patch_size
651
+
652
+ # Get the proper wrapper class
653
+
654
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
655
+ wrapper_fn = ConditionedDiffusionModelWrapper
656
+ elif model_type == "diffusion_prior":
657
+ prior_type = model_config.get("prior_type", None)
658
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
659
+
660
+ if prior_type == "mono_stereo":
661
+ from .diffusion_prior import MonoToStereoDiffusionPrior
662
+ wrapper_fn = MonoToStereoDiffusionPrior
663
+ elif prior_type == "source_separation":
664
+ from .diffusion_prior import SourceSeparationDiffusionPrior
665
+ wrapper_fn = SourceSeparationDiffusionPrior
666
+
667
+ return wrapper_fn(
668
+ diffusion_model,
669
+ conditioner,
670
+ min_input_length=min_input_length,
671
+ sample_rate=sample_rate,
672
+ cross_attn_cond_ids=cross_attention_ids,
673
+ global_cond_ids=global_cond_ids,
674
+ input_concat_ids=input_concat_ids,
675
+ prepend_cond_ids=prepend_cond_ids,
676
+ pretransform=pretransform,
677
+ io_channels=io_channels
678
+ )
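In create_diffusion_cond_from_config, the id lists inside the diffusion block decide where each conditioner's output is injected: cross_attention_cond_ids become cross-attention tokens, global_cond_ids are concatenated into one global embedding, input_concat_ids are stacked onto the model's input channels, and prepend_cond_ids are prepended to the sequence. A rough sketch of the relevant part of such a config, with invented values and the DiffusionTransformer kwargs omitted:

model_config = {
    "model_type": "diffusion_cond",
    "sample_rate": 44100,
    "model": {
        "io_channels": 64,
        "conditioning": {
            "cond_dim": 768,
            "configs": [
                {"id": "prompt",  "type": "t5",     "config": {"t5_model_name": "t5-base"}},
                {"id": "seconds", "type": "number", "config": {"min_val": 0, "max_val": 512}},
            ],
        },
        "diffusion": {
            "type": "dit",
            "cross_attention_cond_ids": ["prompt"],    # routed to cross-attention
            "global_cond_ids": ["seconds"],            # merged into the global embedding
            "config": {
                # DiffusionTransformer kwargs go here; omitted in this sketch
            },
        },
    },
}
# create_diffusion_cond_from_config(model_config) would wrap the DiT in a
# ConditionedDiffusionModelWrapper carrying these routing lists and the conditioner.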
stable_audio_tools/models/diffusion_prior.py ADDED
@@ -0,0 +1,151 @@
1
+ from enum import Enum
2
+ import typing as tp
3
+
4
+ from .diffusion import ConditionedDiffusionModelWrapper
5
+ from ..inference.generation import generate_diffusion_cond
6
+ from ..inference.utils import prepare_audio
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchaudio import transforms as T
11
+
12
+ # Define prior types enum
13
+ class PriorType(Enum):
14
+ MonoToStereo = 1
15
+ SourceSeparation = 2
16
+
17
+ class DiffusionPrior(ConditionedDiffusionModelWrapper):
18
+ def __init__(self, *args, prior_type: PriorType=None, **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self.prior_type = prior_type
21
+
22
+ class MonoToStereoDiffusionPrior(DiffusionPrior):
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)
25
+
26
+ def stereoize(
27
+ self,
28
+ audio: torch.Tensor, # (batch, channels, time)
29
+ in_sr: int,
30
+ steps: int,
31
+ sampler_kwargs: dict = {},
32
+ ):
33
+ """
34
+ Generate stereo audio from mono audio using a pre-trained diffusion prior
35
+
36
+ Args:
37
+ audio: The mono audio to convert to stereo
38
+ in_sr: The sample rate of the input audio
39
+ steps: The number of diffusion steps to run
40
+ sampler_kwargs: Keyword arguments to pass to the diffusion sampler
41
+ """
42
+
43
+ device = audio.device
44
+
45
+ sample_rate = self.sample_rate
46
+
47
+ # Resample input audio if necessary
48
+ if in_sr != sample_rate:
49
+ resample_tf = T.Resample(in_sr, sample_rate).to(audio.device)
50
+ audio = resample_tf(audio)
51
+
52
+ audio_length = audio.shape[-1]
53
+
54
+ # Pad input audio to be compatible with the model
55
+ min_length = self.min_input_length
56
+ padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
57
+
58
+ # Apply the padding if needed
59
+ if padded_input_length > audio_length:
60
+ audio = F.pad(audio, (0, padded_input_length - audio_length))
61
+
62
+ # Make audio mono, duplicate to stereo
63
+ dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)
64
+
65
+ if self.pretransform is not None:
66
+ dual_mono = self.pretransform.encode(dual_mono)
67
+
68
+ conditioning = {"source": [dual_mono]}
69
+
70
+ stereo_audio = generate_diffusion_cond(
71
+ self,
72
+ conditioning_tensors=conditioning,
73
+ steps=steps,
74
+ sample_size=padded_input_length,
75
+ sample_rate=sample_rate,
76
+ device=device,
77
+ **sampler_kwargs,
78
+ )
79
+
80
+ return stereo_audio
81
+
82
+
83
+ class SourceSeparationDiffusionPrior(DiffusionPrior):
84
+ """
85
+ A diffusion prior model made for conditioned source separation
86
+ """
87
+ def __init__(self, *args, **kwargs):
88
+ super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs)
89
+
90
+ def separate(
91
+ self,
92
+ mixed_audio: torch.Tensor, # (batch, channels, time)
93
+ in_sr: int,
94
+ steps: int,
95
+ conditioning: dict = None,
96
+ conditioning_tensors: tp.Optional[dict] = None,
97
+ sampler_kwargs: dict = {},
98
+ ):
99
+ """
100
+ Separate audio sources based on conditioning using a pre-trained diffusion prior
101
+
102
+ Args:
103
+ mixed_audio: The mixed audio to separate
104
+ in_sr: The sample rate of the input audio
105
+ steps: The number of diffusion steps to run
106
+ conditioning: The conditioning to use for source separation
107
+ conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored.
108
+ sampler_kwargs: Keyword arguments to pass to the diffusion sampler
109
+ """
110
+
111
+ device = mixed_audio.device
112
+
113
+ sample_rate = self.sample_rate
114
+
115
+ # Resample input audio if necessary
116
+ if in_sr != sample_rate:
117
+ resample_tf = T.Resample(in_sr, sample_rate).to(mixed_audio.device)
118
+ mixed_audio = resample_tf(mixed_audio)
119
+
120
+ audio_length = mixed_audio.shape[-1]
121
+
122
+ # Pad input audio to be compatible with the model
123
+ min_length = self.min_input_length
124
+ padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
125
+
126
+ # Apply the padding if needed
127
+ if padded_input_length > audio_length:
128
+ mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length))
129
+
130
+ if self.pretransform is not None:
131
+ mixed_audio = self.pretransform.encode(mixed_audio)
132
+
133
+ # Conditioning
134
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation"
135
+ if conditioning_tensors is None:
136
+ conditioning_tensors = self.conditioner(conditioning, device)
137
+
138
+ # Pass in the mixture audio as conditioning
139
+ conditioning_tensors["source"] = [mixed_audio]
140
+
141
+ stereo_audio = generate_diffusion_cond(
142
+ self,
143
+ conditioning_tensors=conditioning_tensors,
144
+ steps=steps,
145
+ sample_size=padded_input_length,
146
+ sample_rate=sample_rate,
147
+ device=device,
148
+ **sampler_kwargs,
149
+ )
150
+
151
+ return stereo_audio
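A hedged usage sketch for the mono-to-stereo prior above; prior_model stands for an already-loaded MonoToStereoDiffusionPrior, and the shapes and step count are illustrative:

import torch

mono = torch.randn(1, 1, 44100 * 5)  # 5 seconds of mono audio at 44.1 kHz
with torch.no_grad():
    # prior_model: an already-loaded MonoToStereoDiffusionPrior (assumed)
    stereo = prior_model.stereoize(mono, in_sr=44100, steps=100)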
stable_audio_tools/models/discriminators.py ADDED
@@ -0,0 +1,546 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from functools import reduce
6
+ import typing as tp
7
+ from einops import rearrange
8
+ from audiotools import AudioSignal, STFTParams
9
+ from dac.model.discriminator import WNConv1d, WNConv2d
10
+
11
+ def get_hinge_losses(score_real, score_fake):
12
+ gen_loss = -score_fake.mean()
13
+ dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
14
+ return dis_loss, gen_loss
15
+
16
+ class EncodecDiscriminator(nn.Module):
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__()
20
+
21
+ from encodec.msstftd import MultiScaleSTFTDiscriminator
22
+
23
+ self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
24
+
25
+ def forward(self, x):
26
+ logits, features = self.discriminators(x)
27
+ return logits, features
28
+
29
+ def loss(self, x, y):
30
+ feature_matching_distance = 0.
31
+ logits_true, feature_true = self.forward(x)
32
+ logits_fake, feature_fake = self.forward(y)
33
+
34
+ dis_loss = torch.tensor(0.)
35
+ adv_loss = torch.tensor(0.)
36
+
37
+ for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
38
+
39
+ feature_matching_distance = feature_matching_distance + sum(
40
+ map(
41
+ lambda x, y: abs(x - y).mean(),
42
+ scale_true,
43
+ scale_fake,
44
+ )) / len(scale_true)
45
+
46
+ _dis, _adv = get_hinge_losses(
47
+ logits_true[i],
48
+ logits_fake[i],
49
+ )
50
+
51
+ dis_loss = dis_loss + _dis
52
+ adv_loss = adv_loss + _adv
53
+
54
+ return dis_loss, adv_loss, feature_matching_distance
55
+
56
+ # Discriminators from oobleck
57
+
58
+ IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
59
+
60
+ TensorDict = tp.Dict[str, torch.Tensor]
61
+
62
+ class SharedDiscriminatorConvNet(nn.Module):
63
+
64
+ def __init__(
65
+ self,
66
+ in_size: int,
67
+ convolution: tp.Union[nn.Conv1d, nn.Conv2d],
68
+ out_size: int = 1,
69
+ capacity: int = 32,
70
+ n_layers: int = 4,
71
+ kernel_size: int = 15,
72
+ stride: int = 4,
73
+ activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
74
+ normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
75
+ ) -> None:
76
+ super().__init__()
77
+ channels = [in_size]
78
+ channels += list(capacity * 2**np.arange(n_layers))
79
+
80
+ if isinstance(stride, int):
81
+ stride = n_layers * [stride]
82
+
83
+ net = []
84
+ for i in range(n_layers):
85
+ if isinstance(kernel_size, int):
86
+ pad = kernel_size // 2
87
+ s = stride[i]
88
+ else:
89
+ pad = kernel_size[0] // 2
90
+ s = (stride[i], 1)
91
+
92
+ net.append(
93
+ normalization(
94
+ convolution(
95
+ channels[i],
96
+ channels[i + 1],
97
+ kernel_size,
98
+ stride=s,
99
+ padding=pad,
100
+ )))
101
+ net.append(activation())
102
+
103
+ net.append(convolution(channels[-1], out_size, 1))
104
+
105
+ self.net = nn.ModuleList(net)
106
+
107
+ def forward(self, x) -> IndividualDiscriminatorOut:
108
+ features = []
109
+ for layer in self.net:
110
+ x = layer(x)
111
+ if isinstance(layer, nn.modules.conv._ConvNd):
112
+ features.append(x)
113
+ score = x.reshape(x.shape[0], -1).mean(-1)
114
+ return score, features
115
+
116
+
117
+ class MultiScaleDiscriminator(nn.Module):
118
+
119
+ def __init__(self,
120
+ in_channels: int,
121
+ n_scales: int,
122
+ **conv_kwargs) -> None:
123
+ super().__init__()
124
+ layers = []
125
+ for _ in range(n_scales):
126
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
127
+ self.layers = nn.ModuleList(layers)
128
+
129
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
130
+ score = 0
131
+ features = []
132
+ for layer in self.layers:
133
+ s, f = layer(x)
134
+ score = score + s
135
+ features.extend(f)
136
+ x = nn.functional.avg_pool1d(x, 2)
137
+ return score, features
138
+
139
+ class MultiPeriodDiscriminator(nn.Module):
140
+
141
+ def __init__(self,
142
+ in_channels: int,
143
+ periods: tp.Sequence[int],
144
+ **conv_kwargs) -> None:
145
+ super().__init__()
146
+ layers = []
147
+ self.periods = periods
148
+
149
+ for _ in periods:
150
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
151
+
152
+ self.layers = nn.ModuleList(layers)
153
+
154
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
155
+ score = 0
156
+ features = []
157
+ for layer, n in zip(self.layers, self.periods):
158
+ s, f = layer(self.fold(x, n))
159
+ score = score + s
160
+ features.extend(f)
161
+ return score, features
162
+
163
+ def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
164
+ pad = (n - (x.shape[-1] % n)) % n
165
+ x = nn.functional.pad(x, (0, pad))
166
+ return x.reshape(*x.shape[:2], -1, n)
167
+
168
+
169
+ class MultiDiscriminator(nn.Module):
170
+ """
171
+ Individual discriminators should take a single tensor as input (NxB C T) and
172
+ return a tuple composed of a score tensor (NxB) and a Sequence of Features
173
+ Sequence[NxB C' T'].
174
+ """
175
+
176
+ def __init__(self, discriminator_list: tp.Sequence[nn.Module],
177
+ keys: tp.Sequence[str]) -> None:
178
+ super().__init__()
179
+ self.discriminators = nn.ModuleList(discriminator_list)
180
+ self.keys = keys
181
+
182
+ def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
183
+ features = features.chunk(len(self.keys), 0)
184
+ return {k: features[i] for i, k in enumerate(self.keys)}
185
+
186
+ @staticmethod
187
+ def concat_dicts(dict_a, dict_b):
188
+ out_dict = {}
189
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
190
+ for k in keys:
191
+ out_dict[k] = []
192
+ if k in dict_a:
193
+ if isinstance(dict_a[k], list):
194
+ out_dict[k].extend(dict_a[k])
195
+ else:
196
+ out_dict[k].append(dict_a[k])
197
+ if k in dict_b:
198
+ if isinstance(dict_b[k], list):
199
+ out_dict[k].extend(dict_b[k])
200
+ else:
201
+ out_dict[k].append(dict_b[k])
202
+ return out_dict
203
+
204
+ @staticmethod
205
+ def sum_dicts(dict_a, dict_b):
206
+ out_dict = {}
207
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
208
+ for k in keys:
209
+ out_dict[k] = 0.
210
+ if k in dict_a:
211
+ out_dict[k] = out_dict[k] + dict_a[k]
212
+ if k in dict_b:
213
+ out_dict[k] = out_dict[k] + dict_b[k]
214
+ return out_dict
215
+
216
+ def forward(self, inputs: TensorDict) -> TensorDict:
217
+ discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
218
+ all_scores = []
219
+ all_features = []
220
+
221
+ for discriminator in self.discriminators:
222
+ score, features = discriminator(discriminator_input)
223
+ scores = self.unpack_tensor_to_dict(score)
224
+ scores = {f"score_{k}": scores[k] for k in scores.keys()}
225
+ all_scores.append(scores)
226
+
227
+ features = map(self.unpack_tensor_to_dict, features)
228
+ features = reduce(self.concat_dicts, features)
229
+ features = {f"features_{k}": features[k] for k in features.keys()}
230
+ all_features.append(features)
231
+
232
+ all_scores = reduce(self.sum_dicts, all_scores)
233
+ all_features = reduce(self.concat_dicts, all_features)
234
+
235
+ inputs.update(all_scores)
236
+ inputs.update(all_features)
237
+
238
+ return inputs
239
+
240
+ class OobleckDiscriminator(nn.Module):
241
+
242
+ def __init__(
243
+ self,
244
+ in_channels=1,
245
+ ):
246
+ super().__init__()
247
+
248
+ multi_scale_discriminator = MultiScaleDiscriminator(
249
+ in_channels=in_channels,
250
+ n_scales=3,
251
+ )
252
+
253
+ multi_period_discriminator = MultiPeriodDiscriminator(
254
+ in_channels=in_channels,
255
+ periods=[2, 3, 5, 7, 11]
256
+ )
257
+
258
+ # multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
259
+ # filters=32,
260
+ # in_channels = in_channels,
261
+ # out_channels = 1,
262
+ # n_ffts = [2048, 1024, 512, 256, 128],
263
+ # hop_lengths = [512, 256, 128, 64, 32],
264
+ # win_lengths = [2048, 1024, 512, 256, 128]
265
+ # )
266
+
267
+ self.multi_discriminator = MultiDiscriminator(
268
+ [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
269
+ ["reals", "fakes"]
270
+ )
271
+
272
+ def loss(self, reals, fakes):
273
+ inputs = {
274
+ "reals": reals,
275
+ "fakes": fakes,
276
+ }
277
+
278
+ inputs = self.multi_discriminator(inputs)
279
+
280
+ scores_real = inputs["score_reals"]
281
+ scores_fake = inputs["score_fakes"]
282
+
283
+ features_real = inputs["features_reals"]
284
+ features_fake = inputs["features_fakes"]
285
+
286
+ dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
287
+
288
+ feature_matching_distance = torch.tensor(0.)
289
+
290
+ for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
291
+
292
+ feature_matching_distance = feature_matching_distance + sum(
293
+ map(
294
+ lambda real, fake: abs(real - fake).mean(),
295
+ scale_real,
296
+ scale_fake,
297
+ )) / len(scale_real)
298
+
299
+ return dis_loss, gen_loss, feature_matching_distance
300
+
301
+
302
+ ## Discriminators from Descript Audio Codec repo
303
+ ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
304
+ class MPD(nn.Module):
305
+ def __init__(self, period, channels=1):
306
+ super().__init__()
307
+
308
+ self.period = period
309
+ self.convs = nn.ModuleList(
310
+ [
311
+ WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
312
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
313
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
314
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
315
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
316
+ ]
317
+ )
318
+ self.conv_post = WNConv2d(
319
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
320
+ )
321
+
322
+ def pad_to_period(self, x):
323
+ t = x.shape[-1]
324
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
325
+ return x
326
+
327
+ def forward(self, x):
328
+ fmap = []
329
+
330
+ x = self.pad_to_period(x)
331
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
332
+
333
+ for layer in self.convs:
334
+ x = layer(x)
335
+ fmap.append(x)
336
+
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+
340
+ return fmap
341
+
342
+
343
+ class MSD(nn.Module):
344
+ def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
345
+ super().__init__()
346
+
347
+ self.convs = nn.ModuleList(
348
+ [
349
+ WNConv1d(channels, 16, 15, 1, padding=7),
350
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
351
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
352
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
353
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
354
+ WNConv1d(1024, 1024, 5, 1, padding=2),
355
+ ]
356
+ )
357
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
358
+ self.sample_rate = sample_rate
359
+ self.rate = rate
360
+
361
+ def forward(self, x):
362
+ x = AudioSignal(x, self.sample_rate)
363
+ x.resample(self.sample_rate // self.rate)
364
+ x = x.audio_data
365
+
366
+ fmap = []
367
+
368
+ for l in self.convs:
369
+ x = l(x)
370
+ fmap.append(x)
371
+ x = self.conv_post(x)
372
+ fmap.append(x)
373
+
374
+ return fmap
375
+
376
+
377
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
378
+
379
+
380
+ class MRD(nn.Module):
381
+ def __init__(
382
+ self,
383
+ window_length: int,
384
+ hop_factor: float = 0.25,
385
+ sample_rate: int = 44100,
386
+ bands: list = BANDS,
387
+ channels: int = 1
388
+ ):
389
+ """Complex multi-band spectrogram discriminator.
390
+ Parameters
391
+ ----------
392
+ window_length : int
393
+ Window length of STFT.
394
+ hop_factor : float, optional
395
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
396
+ sample_rate : int, optional
397
+ Sampling rate of audio in Hz, by default 44100
398
+ bands : list, optional
399
+ Bands to run discriminator over.
400
+ """
401
+ super().__init__()
402
+
403
+ self.window_length = window_length
404
+ self.hop_factor = hop_factor
405
+ self.sample_rate = sample_rate
406
+ self.stft_params = STFTParams(
407
+ window_length=window_length,
408
+ hop_length=int(window_length * hop_factor),
409
+ match_stride=True,
410
+ )
411
+
412
+ self.channels = channels
413
+
414
+ n_fft = window_length // 2 + 1
415
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
416
+ self.bands = bands
417
+
418
+ ch = 32
419
+ convs = lambda: nn.ModuleList(
420
+ [
421
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
422
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
423
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
424
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
425
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
426
+ ]
427
+ )
428
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
429
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
430
+
431
+ def spectrogram(self, x):
432
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
433
+ x = torch.view_as_real(x.stft())
434
+ x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
435
+ # Split into bands
436
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
437
+ return x_bands
438
+
439
+ def forward(self, x):
440
+ x_bands = self.spectrogram(x)
441
+ fmap = []
442
+
443
+ x = []
444
+ for band, stack in zip(x_bands, self.band_convs):
445
+ for layer in stack:
446
+ band = layer(band)
447
+ fmap.append(band)
448
+ x.append(band)
449
+
450
+ x = torch.cat(x, dim=-1)
451
+ x = self.conv_post(x)
452
+ fmap.append(x)
453
+
454
+ return fmap
455
+
456
+
457
+ class DACDiscriminator(nn.Module):
458
+ def __init__(
459
+ self,
460
+ channels: int = 1,
461
+ rates: list = [],
462
+ periods: list = [2, 3, 5, 7, 11],
463
+ fft_sizes: list = [2048, 1024, 512],
464
+ sample_rate: int = 44100,
465
+ bands: list = BANDS,
466
+ ):
467
+ """Discriminator that combines multiple discriminators.
468
+
469
+ Parameters
470
+ ----------
471
+ rates : list, optional
472
+ sampling rates (in Hz) to run MSD at, by default []
473
+ If empty, MSD is not used.
474
+ periods : list, optional
475
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
476
+ fft_sizes : list, optional
477
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
478
+ sample_rate : int, optional
479
+ Sampling rate of audio in Hz, by default 44100
480
+ bands : list, optional
481
+ Bands to run MRD at, by default `BANDS`
482
+ """
483
+ super().__init__()
484
+ discs = []
485
+ discs += [MPD(p, channels=channels) for p in periods]
486
+ discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
487
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
488
+ self.discriminators = nn.ModuleList(discs)
489
+
490
+ def preprocess(self, y):
491
+ # Remove DC offset
492
+ y = y - y.mean(dim=-1, keepdims=True)
493
+ # Peak normalize the volume of input audio
494
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
495
+ return y
496
+
497
+ def forward(self, x):
498
+ x = self.preprocess(x)
499
+ fmaps = [d(x) for d in self.discriminators]
500
+ return fmaps
501
+
502
+ class DACGANLoss(nn.Module):
503
+ """
504
+ Computes a discriminator loss, given a discriminator on
505
+ generated waveforms/spectrograms compared to ground truth
506
+ waveforms/spectrograms. Computes the loss for both the
507
+ discriminator and the generator in separate functions.
508
+ """
509
+
510
+ def __init__(self, **discriminator_kwargs):
511
+ super().__init__()
512
+ self.discriminator = DACDiscriminator(**discriminator_kwargs)
513
+
514
+ def forward(self, fake, real):
515
+ d_fake = self.discriminator(fake)
516
+ d_real = self.discriminator(real)
517
+ return d_fake, d_real
518
+
519
+ def discriminator_loss(self, fake, real):
520
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
521
+
522
+ loss_d = 0
523
+ for x_fake, x_real in zip(d_fake, d_real):
524
+ loss_d += torch.mean(x_fake[-1] ** 2)
525
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
526
+ return loss_d
527
+
528
+ def generator_loss(self, fake, real):
529
+ d_fake, d_real = self.forward(fake, real)
530
+
531
+ loss_g = 0
532
+ for x_fake in d_fake:
533
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
534
+
535
+ loss_feature = 0
536
+
537
+ for i in range(len(d_fake)):
538
+ for j in range(len(d_fake[i]) - 1):
539
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
540
+ return loss_g, loss_feature
541
+
542
+ def loss(self, fake, real):
543
+ gen_loss, feature_distance = self.generator_loss(fake, real)
544
+ dis_loss = self.discriminator_loss(fake, real)
545
+
546
+ return dis_loss, gen_loss, feature_distance
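For orientation, an illustrative snippet of how DACGANLoss could be wired into a GAN training step; the batch shapes and constructor kwargs are assumptions, not taken from this commit's training code:

import torch
from stable_audio_tools.models.discriminators import DACGANLoss

gan_loss = DACGANLoss(channels=2, sample_rate=44100)
reals = torch.randn(4, 2, 16384)   # ground-truth audio batch (assumed shape)
fakes = torch.randn(4, 2, 16384)   # generator/decoder output (assumed shape)

dis_loss = gan_loss.discriminator_loss(fakes, reals)          # step the discriminator with this
gen_loss, feat_loss = gan_loss.generator_loss(fakes, reals)   # step the generator with these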
stable_audio_tools/models/dit.py ADDED
@@ -0,0 +1,358 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+
13
+ class DiffusionTransformer(nn.Module):
14
+ def __init__(self,
15
+ io_channels=32,
16
+ patch_size=1,
17
+ embed_dim=768,
18
+ cond_token_dim=0,
19
+ project_cond_tokens=True,
20
+ global_cond_dim=0,
21
+ input_concat_dim=0,
22
+ prepend_cond_dim=0,
23
+ depth=12,
24
+ num_heads=8,
25
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
26
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
27
+ **kwargs):
28
+
29
+ super().__init__()
30
+
31
+ self.cond_token_dim = cond_token_dim
32
+
33
+ # Timestep embeddings
34
+ timestep_features_dim = 256
35
+
36
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
37
+
38
+ self.to_timestep_embed = nn.Sequential(
39
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
40
+ nn.SiLU(),
41
+ nn.Linear(embed_dim, embed_dim, bias=True),
42
+ )
43
+
44
+ if cond_token_dim > 0:
45
+ # Conditioning tokens
46
+
47
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
48
+ self.to_cond_embed = nn.Sequential(
49
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
50
+ nn.SiLU(),
51
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
52
+ )
53
+ else:
54
+ cond_embed_dim = 0
55
+
56
+ if global_cond_dim > 0:
57
+ # Global conditioning
58
+ self.to_global_embed = nn.Sequential(
59
+ nn.Linear(global_cond_dim, embed_dim, bias=False),
60
+ nn.SiLU(),
61
+ nn.Linear(embed_dim, embed_dim, bias=False)
62
+ )
63
+
64
+ if prepend_cond_dim > 0:
65
+ # Prepend conditioning
66
+ self.to_prepend_embed = nn.Sequential(
67
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
68
+ nn.SiLU(),
69
+ nn.Linear(embed_dim, embed_dim, bias=False)
70
+ )
71
+
72
+ self.input_concat_dim = input_concat_dim
73
+
74
+ dim_in = io_channels + self.input_concat_dim
75
+
76
+ self.patch_size = patch_size
77
+
78
+ # Transformer
79
+
80
+ self.transformer_type = transformer_type
81
+
82
+ self.global_cond_type = global_cond_type
83
+
84
+ if self.transformer_type == "x-transformers":
85
+ self.transformer = ContinuousTransformerWrapper(
86
+ dim_in=dim_in * patch_size,
87
+ dim_out=io_channels * patch_size,
88
+ max_seq_len=0, #Not relevant without absolute positional embeds
89
+ attn_layers = Encoder(
90
+ dim=embed_dim,
91
+ depth=depth,
92
+ heads=num_heads,
93
+ attn_flash = True,
94
+ cross_attend = cond_token_dim > 0,
95
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
96
+ zero_init_branch_output=True,
97
+ use_abs_pos_emb = False,
98
+ rotary_pos_emb=True,
99
+ ff_swish = True,
100
+ ff_glu = True,
101
+ **kwargs
102
+ )
103
+ )
104
+
105
+ elif self.transformer_type == "continuous_transformer":
106
+
107
+ global_dim = None
108
+
109
+ if self.global_cond_type == "adaLN":
110
+ # The global conditioning is projected to the embed_dim already at this point
111
+ global_dim = embed_dim
112
+
113
+ self.transformer = ContinuousTransformer(
114
+ dim=embed_dim,
115
+ depth=depth,
116
+ dim_heads=embed_dim // num_heads,
117
+ dim_in=dim_in * patch_size,
118
+ dim_out=io_channels * patch_size,
119
+ cross_attend = cond_token_dim > 0,
120
+ cond_token_dim = cond_embed_dim,
121
+ global_cond_dim=global_dim,
122
+ **kwargs
123
+ )
124
+
125
+ else:
126
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
127
+
128
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
129
+ nn.init.zeros_(self.preprocess_conv.weight)
130
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
131
+ nn.init.zeros_(self.postprocess_conv.weight)
132
+
133
+ def _forward(
134
+ self,
135
+ x,
136
+ t,
137
+ mask=None,
138
+ cross_attn_cond=None,
139
+ cross_attn_cond_mask=None,
140
+ input_concat_cond=None,
141
+ global_embed=None,
142
+ prepend_cond=None,
143
+ prepend_cond_mask=None,
144
+ **kwargs):
145
+
146
+ if cross_attn_cond is not None:
147
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
148
+
149
+ if global_embed is not None:
150
+ # Project the global conditioning to the embedding dimension
151
+ global_embed = self.to_global_embed(global_embed)
152
+
153
+ prepend_inputs = None
154
+ prepend_mask = None
155
+ prepend_length = 0
156
+ if prepend_cond is not None:
157
+ # Project the prepend conditioning to the embedding dimension
158
+ prepend_cond = self.to_prepend_embed(prepend_cond)
159
+
160
+ prepend_inputs = prepend_cond
161
+ if prepend_cond_mask is not None:
162
+ prepend_mask = prepend_cond_mask
163
+
164
+ if input_concat_cond is not None:
165
+
166
+ # Interpolate input_concat_cond to the same length as x
167
+ if input_concat_cond.shape[2] != x.shape[2]:
168
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
169
+
170
+ x = torch.cat([x, input_concat_cond], dim=1)
171
+
172
+ # Get the batch of timestep embeddings
173
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
174
+
175
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
176
+ if global_embed is not None:
177
+ global_embed = global_embed + timestep_embed
178
+ else:
179
+ global_embed = timestep_embed
180
+
181
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
182
+ if self.global_cond_type == "prepend":
183
+ if prepend_inputs is None:
184
+ # Prepend inputs are just the global embed, and the mask is all ones
185
+ prepend_inputs = global_embed.unsqueeze(1)
186
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
187
+ else:
188
+ # Prepend inputs are the prepend conditioning + the global embed
189
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
190
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
191
+
192
+ prepend_length = prepend_inputs.shape[1]
193
+
194
+ x = self.preprocess_conv(x) + x
195
+
196
+ x = rearrange(x, "b c t -> b t c")
197
+
198
+ extra_args = {}
199
+
200
+ if self.global_cond_type == "adaLN":
201
+ extra_args["global_cond"] = global_embed
202
+
203
+ if self.patch_size > 1:
204
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
205
+
206
+ if self.transformer_type == "x-transformers" or self.transformer_type == "continuous_transformer":
207
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
208
+
209
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
210
+
211
+ if self.patch_size > 1:
212
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
213
+
214
+ output = self.postprocess_conv(output) + output
215
+
216
+ return output
217
+
218
+ def forward(
219
+ self,
220
+ x,
221
+ t,
222
+ cross_attn_cond=None,
223
+ cross_attn_cond_mask=None,
224
+ negative_cross_attn_cond=None,
225
+ negative_cross_attn_mask=None,
226
+ input_concat_cond=None,
227
+ global_embed=None,
228
+ negative_global_embed=None,
229
+ prepend_cond=None,
230
+ prepend_cond_mask=None,
231
+ cfg_scale=1.0,
232
+ cfg_dropout_prob=0.0,
233
+ causal=False,
234
+ scale_phi=0.0,
235
+ mask=None,
236
+ **kwargs):
237
+
238
+ assert not causal, "Causal mode is not supported for DiffusionTransformer"
239
+
240
+ if cross_attn_cond_mask is not None:
241
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
242
+
243
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
244
+
245
+ if prepend_cond_mask is not None:
246
+ prepend_cond_mask = prepend_cond_mask.bool()
247
+
248
+ # CFG dropout
249
+ if cfg_dropout_prob > 0.0:
250
+ if cross_attn_cond is not None:
251
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
252
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
253
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
254
+
255
+ if prepend_cond is not None:
256
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
257
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
258
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
259
+
260
+
261
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
262
+ # Classifier-free guidance
263
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
264
+ batch_inputs = torch.cat([x, x], dim=0)
265
+ batch_timestep = torch.cat([t, t], dim=0)
266
+
267
+ if global_embed is not None:
268
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
269
+ else:
270
+ batch_global_cond = None
271
+
272
+ if input_concat_cond is not None:
273
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
274
+ else:
275
+ batch_input_concat_cond = None
276
+
277
+ batch_cond = None
278
+ batch_cond_masks = None
279
+
280
+ # Handle CFG for cross-attention conditioning
281
+ if cross_attn_cond is not None:
282
+
283
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
284
+
285
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
286
+ if negative_cross_attn_cond is not None:
287
+
288
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
289
+ if negative_cross_attn_mask is not None:
290
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
291
+
292
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
293
+
294
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
295
+
296
+ else:
297
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
298
+
299
+ if cross_attn_cond_mask is not None:
300
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
301
+
302
+ batch_prepend_cond = None
303
+ batch_prepend_cond_mask = None
304
+
305
+ if prepend_cond is not None:
306
+
307
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
308
+
309
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
310
+
311
+ if prepend_cond_mask is not None:
312
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
313
+
314
+
315
+ if mask is not None:
316
+ batch_masks = torch.cat([mask, mask], dim=0)
317
+ else:
318
+ batch_masks = None
319
+
320
+ batch_output = self._forward(
321
+ batch_inputs,
322
+ batch_timestep,
323
+ cross_attn_cond=batch_cond,
324
+ cross_attn_cond_mask=batch_cond_masks,
325
+ mask = batch_masks,
326
+ input_concat_cond=batch_input_concat_cond,
327
+ global_embed = batch_global_cond,
328
+ prepend_cond = batch_prepend_cond,
329
+ prepend_cond_mask = batch_prepend_cond_mask,
330
+ **kwargs)
331
+
332
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
333
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
334
+
335
+ if scale_phi != 0.0:
336
+
337
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
338
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
339
+
340
+ return scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
341
+
342
+ else:
343
+
344
+ return cfg_output
345
+
346
+ else:
347
+ return self._forward(
348
+ x,
349
+ t,
350
+ cross_attn_cond=cross_attn_cond,
351
+ cross_attn_cond_mask=cross_attn_cond_mask,
352
+ input_concat_cond=input_concat_cond,
353
+ global_embed=global_embed,
354
+ prepend_cond=prepend_cond,
355
+ prepend_cond_mask=prepend_cond_mask,
356
+ mask=mask,
357
+ **kwargs
358
+ )
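A minimal instantiation sketch for the DiffusionTransformer above; the hyperparameters are illustrative and assume the x-transformers dependency (the default transformer_type) is installed:

import torch
from stable_audio_tools.models.dit import DiffusionTransformer

dit = DiffusionTransformer(
    io_channels=64,       # latent channels of the pretransform (assumed)
    embed_dim=768,
    depth=12,
    num_heads=8,
    cond_token_dim=768,   # dimension of cross-attention conditioning tokens
)

x = torch.randn(2, 64, 256)      # (batch, io_channels, latent length)
t = torch.rand(2)                # diffusion timesteps in [0, 1]
cond = torch.randn(2, 77, 768)   # cross-attention conditioning tokens
out = dit(x, t, cross_attn_cond=cond)   # same shape as x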
stable_audio_tools/models/factory.py ADDED
@@ -0,0 +1,149 @@
1
+ import json
2
+
3
+ def create_model_from_config(model_config):
4
+ model_type = model_config.get('model_type', None)
5
+
6
+ assert model_type is not None, 'model_type must be specified in model config'
7
+
8
+ if model_type == 'autoencoder':
9
+ from .autoencoders import create_autoencoder_from_config
10
+ return create_autoencoder_from_config(model_config)
11
+ elif model_type == 'diffusion_uncond':
12
+ from .diffusion import create_diffusion_uncond_from_config
13
+ return create_diffusion_uncond_from_config(model_config)
14
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
15
+ from .diffusion import create_diffusion_cond_from_config
16
+ return create_diffusion_cond_from_config(model_config)
17
+ elif model_type == 'diffusion_autoencoder':
18
+ from .autoencoders import create_diffAE_from_config
19
+ return create_diffAE_from_config(model_config)
20
+ elif model_type == 'musicgen':
21
+ from .musicgen import create_musicgen_from_config
22
+ return create_musicgen_from_config(model_config)
23
+ elif model_type == 'lm':
24
+ from .lm import create_audio_lm_from_config
25
+ return create_audio_lm_from_config(model_config)
26
+ else:
27
+ raise NotImplementedError(f'Unknown model type: {model_type}')
28
+
29
+ def create_model_from_config_path(model_config_path):
30
+ with open(model_config_path) as f:
31
+ model_config = json.load(f)
32
+
33
+ return create_model_from_config(model_config)
34
+
35
+ def create_pretransform_from_config(pretransform_config, sample_rate):
36
+ pretransform_type = pretransform_config.get('type', None)
37
+
38
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
39
+
40
+ if pretransform_type == 'autoencoder':
41
+ from .autoencoders import create_autoencoder_from_config
42
+ from .pretransforms import AutoencoderPretransform
43
+
44
+ # Create fake top-level config to pass sample rate to autoencoder constructor
45
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
46
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
47
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
48
+
49
+ scale = pretransform_config.get("scale", 1.0)
50
+ model_half = pretransform_config.get("model_half", False)
51
+ iterate_batch = pretransform_config.get("iterate_batch", False)
52
+ chunked = pretransform_config.get("chunked", False)
53
+
54
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
55
+ elif pretransform_type == 'wavelet':
56
+ from .pretransforms import WaveletPretransform
57
+
58
+ wavelet_config = pretransform_config["config"]
59
+ channels = wavelet_config["channels"]
60
+ levels = wavelet_config["levels"]
61
+ wavelet = wavelet_config["wavelet"]
62
+
63
+ pretransform = WaveletPretransform(channels, levels, wavelet)
64
+ elif pretransform_type == 'pqmf':
65
+ from .pretransforms import PQMFPretransform
66
+ pqmf_config = pretransform_config["config"]
67
+ pretransform = PQMFPretransform(**pqmf_config)
68
+ elif pretransform_type == 'dac_pretrained':
69
+ from .pretransforms import PretrainedDACPretransform
70
+ pretrained_dac_config = pretransform_config["config"]
71
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
72
+ elif pretransform_type == "audiocraft_pretrained":
73
+ from .pretransforms import AudiocraftCompressionPretransform
74
+
75
+ audiocraft_config = pretransform_config["config"]
76
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
77
+ else:
78
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
79
+
80
+ enable_grad = pretransform_config.get('enable_grad', False)
81
+ pretransform.enable_grad = enable_grad
82
+
83
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
84
+
85
+ return pretransform
86
+
87
+ def create_bottleneck_from_config(bottleneck_config):
88
+ bottleneck_type = bottleneck_config.get('type', None)
89
+
90
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
91
+
92
+ if bottleneck_type == 'tanh':
93
+ from .bottleneck import TanhBottleneck
94
+ return TanhBottleneck()
95
+ elif bottleneck_type == 'vae':
96
+ from .bottleneck import VAEBottleneck
97
+ return VAEBottleneck()
98
+ elif bottleneck_type == 'rvq':
99
+ from .bottleneck import RVQBottleneck
100
+
101
+ quantizer_params = {
102
+ "dim": 128,
103
+ "codebook_size": 1024,
104
+ "num_quantizers": 8,
105
+ "decay": 0.99,
106
+ "kmeans_init": True,
107
+ "kmeans_iters": 50,
108
+ "threshold_ema_dead_code": 2,
109
+ }
110
+
111
+ quantizer_params.update(bottleneck_config["config"])
112
+
113
+ return RVQBottleneck(**quantizer_params)
114
+ elif bottleneck_type == "dac_rvq":
115
+ from .bottleneck import DACRVQBottleneck
116
+
117
+ return DACRVQBottleneck(**bottleneck_config["config"])
118
+
119
+ elif bottleneck_type == 'rvq_vae':
120
+ from .bottleneck import RVQVAEBottleneck
121
+
122
+ quantizer_params = {
123
+ "dim": 128,
124
+ "codebook_size": 1024,
125
+ "num_quantizers": 8,
126
+ "decay": 0.99,
127
+ "kmeans_init": True,
128
+ "kmeans_iters": 50,
129
+ "threshold_ema_dead_code": 2,
130
+ }
131
+
132
+ quantizer_params.update(bottleneck_config["config"])
133
+
134
+ return RVQVAEBottleneck(**quantizer_params)
135
+
136
+ elif bottleneck_type == 'dac_rvq_vae':
137
+ from .bottleneck import DACRVQVAEBottleneck
138
+ return DACRVQVAEBottleneck(**bottleneck_config["config"])
139
+ elif bottleneck_type == 'l2_norm':
140
+ from .bottleneck import L2Bottleneck
141
+ return L2Bottleneck()
142
+ elif bottleneck_type == "wasserstein":
143
+ from .bottleneck import WassersteinBottleneck
144
+ return WassersteinBottleneck(**bottleneck_config.get("config", {}))
145
+ elif bottleneck_type == "fsq":
146
+ from .bottleneck import FSQBottleneck
147
+ return FSQBottleneck(**bottleneck_config["config"])
148
+ else:
149
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
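As a quick illustration of how these factories consume config dictionaries (the values here are made up, not taken from the shipped configs):

from stable_audio_tools.models.factory import create_bottleneck_from_config

# "vae" needs no extra config block; types like "rvq" or "fsq" expect a "config" dict.
bottleneck = create_bottleneck_from_config({"type": "vae"})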
stable_audio_tools/models/lm.py ADDED
@@ -0,0 +1,531 @@
1
+ from dataclasses import dataclass
2
+ import torch
3
+ from tqdm.auto import trange
4
+ import typing as tp
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from .autoencoders import AudioAutoencoder
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ from .factory import create_pretransform_from_config
11
+ from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
12
+ from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
13
+ from .utils import multinomial, sample_top_k, sample_top_p
14
+
15
+ from audiocraft.modules.codebooks_patterns import (
16
+ CodebooksPatternProvider,
17
+ DelayedPatternProvider,
18
+ MusicLMPattern,
19
+ ParallelPatternProvider,
20
+ UnrolledPatternProvider,
21
+ VALLEPattern,
22
+ )
23
+
24
+ # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
25
+ # License can be found in LICENSES/LICENSE_META.txt
26
+
27
+ @dataclass
28
+ class LMOutput:
29
+ # The logits are already re-aligned with the input codes
30
+ # hence no extra shift is required, e.g. when computing CE
31
+ logits: torch.Tensor # [B, K, T, card]
32
+ mask: torch.Tensor # [B, K, T]
33
+
34
+ # Wrapper for a multi-codebook language model
35
+ # Handles patterns and quantizer heads
36
+ class AudioLanguageModel(nn.Module):
37
+ def __init__(
38
+ self,
39
+ pattern_provider: CodebooksPatternProvider,
40
+ backbone: AudioLMBackbone,
41
+ num_quantizers: int,
42
+ codebook_size: int
43
+ ):
44
+ super().__init__()
45
+
46
+ self.pattern_provider = pattern_provider
47
+ self.backbone = backbone
48
+ self.num_quantizers = num_quantizers
49
+ self.codebook_size = codebook_size
50
+
51
+ self.masked_token_id = codebook_size
52
+
53
+ # Per-quantizer embedders
54
+ # Add one for the mask embed
55
+ self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
56
+
57
+ # Per-quantizer output heads
58
+ self.quantizer_heads = nn.ModuleList([
59
+ nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
60
+ ])
61
+
62
+ def forward(self,
63
+ sequence: torch.Tensor, # [batch, num_quantizers, seq_len]
64
+ prepend_cond=None, #[batch, seq, channels]
65
+ prepend_cond_mask=None,
66
+ cross_attn_cond=None, #[batch, seq, channels],
67
+ **kwargs
68
+ ):
69
+
70
+ batch, num_quantizers, seq_len = sequence.shape
71
+
72
+ assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
73
+
74
+ backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
75
+
76
+ output = self.backbone(
77
+ backbone_input,
78
+ cross_attn_cond=cross_attn_cond,
79
+ prepend_cond=prepend_cond,
80
+ prepend_cond_mask=prepend_cond_mask,
81
+ **kwargs
82
+ ) # [batch, seq_len, embed_dim]
83
+
84
+ # Run output through quantizer heads
85
+ logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
86
+
87
+ return logits
88
+
89
+ def compute_logits(
90
+ self,
91
+ codes, #[batch, num_quantizers, seq_len]
92
+ **kwargs):
93
+ """
94
+ Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
95
+ Handles translation between input sequence and pattern-shifted sequence
96
+ Only used during training
97
+ """
98
+
99
+ batch, _, seq_len = codes.shape
100
+
101
+ pattern = self.pattern_provider.get_pattern(seq_len)
102
+
103
+ # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
104
+ shifted_codes, _, _ = pattern.build_pattern_sequence(
105
+ codes,
106
+ self.masked_token_id,
107
+ keep_only_valid_steps=True
108
+ )
109
+
110
+ # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
111
+ logits = self(shifted_codes, **kwargs)
112
+
113
+ # Rearrange logits to prepare to revert pattern
114
+ logits = rearrange(logits, "b n s c -> b c n s")
115
+
116
+ # Revert sequence logits back to original sequence length, removing masked steps
117
+ logits, _, logits_mask = pattern.revert_pattern_logits(
118
+ logits, float('nan'), keep_only_valid_steps=True
119
+ )
120
+
121
+ logits = rearrange(logits, "b c n t -> b n t c")
122
+
123
+ logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
124
+
125
+ return LMOutput(logits=logits, mask=logits_mask)
126
+
127
+ # Conditioning and generation wrapper for a multi-codebook language model
128
+ # Handles conditioning, CFG, generation, and encoding/decoding
129
+ class AudioLanguageModelWrapper(nn.Module):
130
+ def __init__(
131
+ self,
132
+ pretransform: Pretransform,
133
+ lm: AudioLanguageModel,
134
+ sample_rate: int,
135
+ min_input_length: int,
136
+ conditioner: MultiConditioner = None,
137
+ cross_attn_cond_ids: tp.List[str] = [],
138
+ prepend_cond_ids: tp.List[str] = [],
139
+ global_cond_ids: tp.List[str] = []
140
+ ):
141
+ super().__init__()
142
+
143
+ assert pretransform.is_discrete, "Pretransform must be discrete"
144
+ self.pretransform = pretransform
145
+
146
+ self.pretransform.requires_grad_(False)
147
+ self.pretransform.eval()
148
+
149
+ if isinstance(self.pretransform, AutoencoderPretransform):
150
+ self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
151
+ self.codebook_size = self.pretransform.model.bottleneck.codebook_size
152
+ elif isinstance(self.pretransform, PretrainedDACPretransform):
153
+ self.num_quantizers = self.pretransform.model.num_quantizers
154
+ self.codebook_size = self.pretransform.model.codebook_size
155
+ elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
156
+ self.num_quantizers = self.pretransform.num_quantizers
157
+ self.codebook_size = self.pretransform.codebook_size
158
+ else:
159
+ raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
160
+
161
+ self.conditioner = conditioner
162
+
163
+ self.lm = lm
164
+
165
+ self.sample_rate = sample_rate
166
+ self.min_input_length = min_input_length
167
+
168
+ self.cross_attn_cond_ids = cross_attn_cond_ids
169
+ self.prepend_cond_ids = prepend_cond_ids
170
+ self.global_cond_ids = global_cond_ids
171
+
172
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
173
+ cross_attention_input = None
174
+ prepend_cond = None
175
+ prepend_cond_mask = None
176
+ global_cond = None
177
+
178
+ if len(self.cross_attn_cond_ids) > 0:
179
+ # Concatenate all cross-attention inputs over the sequence dimension
180
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
181
+ cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
182
+
183
+ if len(self.prepend_cond_ids) > 0:
184
+ # Concatenate all prepend conditioning inputs over the sequence dimension
185
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
186
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
187
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
188
+
189
+ if len(self.global_cond_ids) > 0:
190
+ # Concatenate all global conditioning inputs over the channel dimension
191
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
192
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
193
+ if len(global_cond.shape) == 3:
194
+ global_cond = global_cond.squeeze(1)
195
+
196
+ if negative:
197
+ return {
198
+ "negative_cross_attn_cond": cross_attention_input,
199
+ "negative_prepend_cond": prepend_cond,
200
+ "negative_prepend_cond_mask": prepend_cond_mask,
201
+ "negative_global_cond": global_cond
202
+ }
203
+ else:
204
+ return {
205
+ "cross_attn_cond": cross_attention_input,
206
+ "prepend_cond": prepend_cond,
207
+ "prepend_cond_mask": prepend_cond_mask,
208
+ "global_cond": global_cond
209
+ }
210
+
211
+ def compute_logits(
212
+ self,
213
+ codes,
214
+ condition_tensors=None,
215
+ cfg_dropout_prob=0.0,
216
+ **kwargs
217
+ ):
218
+ """
219
+ Compute logits for a batch of codes, and translates from conditioning inputs to model inputs
220
+ Handles CFG dropout
221
+ """
222
+
223
+ if condition_tensors is None:
224
+ condition_tensors = {}
225
+
226
+ conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
227
+
228
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
229
+ prepend_cond = conditioning_inputs["prepend_cond"]
230
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
231
+ global_cond = conditioning_inputs["global_cond"]
232
+
233
+ if cfg_dropout_prob > 0.0:
234
+ if cross_attn_cond is not None:
235
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
236
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
237
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
238
+
239
+ if prepend_cond is not None:
240
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
241
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
242
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
243
+
244
+ if global_cond is not None:
245
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
246
+ dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
247
+ global_cond = torch.where(dropout_mask, null_embed, global_cond)
248
+
249
+ return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
250
+
251
+ def _sample_next_token(
252
+ self,
253
+ sequence, #[batch, num_quantizers, seq_len]
254
+ conditioning_tensors=None,
255
+ cross_attn_use_cfg=True,
256
+ prepend_use_cfg=True,
257
+ global_use_cfg=True,
258
+ cfg_scale=1.0,
259
+ top_k=250,
260
+ top_p=0.0,
261
+ temp=1.0,
262
+ **kwargs
263
+ ):
264
+ """
265
+ Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs
266
+ Handles CFG inference
267
+ """
268
+
269
+ if conditioning_tensors is None:
270
+ conditioning_tensors = {}
271
+
272
+ conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
273
+
274
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
275
+ prepend_cond = conditioning_inputs["prepend_cond"]
276
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
277
+ global_cond = conditioning_inputs["global_cond"]
278
+
279
+ if cfg_scale != 1.0:
280
+
281
+ # Batch size is doubled to account for negative samples
282
+ sequence = torch.cat([sequence, sequence], dim=0)
283
+
284
+ if cross_attn_cond is not None and cross_attn_use_cfg:
285
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
286
+
287
+ cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
288
+
289
+ if prepend_cond is not None and prepend_use_cfg:
290
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
291
+
292
+ prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
293
+
294
+ if prepend_cond_mask is not None:
295
+ prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
296
+
297
+ if global_cond is not None and global_use_cfg:
298
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
299
+
300
+ global_cond = torch.cat([global_cond, null_embed], dim=0)
301
+
302
+ logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
303
+
304
+ if cfg_scale != 1.0:
305
+ cond_logits, uncond_logits = logits.chunk(2, dim=0)
306
+
307
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
308
+
309
+ logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
310
+
311
+ # Grab the logits for the last step
312
+ logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
313
+
314
+ # Apply top-k or top-p sampling
315
+
316
+ if temp > 0:
317
+ probs = torch.softmax(logits / temp, dim=-1)
318
+
319
+ if top_p > 0.0:
320
+ next_token = sample_top_p(probs, p=top_p)
321
+ elif top_k > 0:
322
+ next_token = sample_top_k(probs, k=top_k)
323
+ else:
324
+ next_token = multinomial(probs, num_samples=1)
325
+
326
+ else:
327
+ next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
328
+
329
+ return next_token
330
+
331
+ @torch.no_grad()
332
+ def generate(
333
+ self,
334
+ max_gen_len: int = 256,
335
+ batch_size: tp.Optional[int] = None,
336
+ init_data: tp.Optional[torch.Tensor] = None,
337
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
338
+ conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
339
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
340
+ use_cache: bool = True,
341
+ cfg_scale: float = 1.0,
342
+ **kwargs
343
+ ):
344
+ device = next(self.parameters()).device
345
+
346
+ if conditioning_tensors is None and conditioning is not None:
347
+ # Convert conditioning inputs to conditioning tensors
348
+ conditioning_tensors = self.conditioner(conditioning, device)
349
+
350
+ # Check that batch size is consistent across inputs
351
+ possible_batch_sizes = []
352
+
353
+ if batch_size is not None:
354
+ possible_batch_sizes.append(batch_size)
355
+ elif init_data is not None:
356
+ possible_batch_sizes.append(init_data.shape[0])
357
+ elif conditioning_tensors is not None:
358
+ # Assume that the first conditioning tensor has the batch dimension
359
+ possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
360
+ else:
361
+ possible_batch_sizes.append(1)
362
+
363
+ assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"
364
+
365
+ batch_size = possible_batch_sizes[0]
366
+
367
+ if init_data is None:
368
+ # Initialize with zeros
369
+ assert batch_size > 0
370
+ init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
371
+
372
+ batch_size, num_quantizers, seq_len = init_data.shape
373
+
374
+ start_offset = seq_len
375
+ assert start_offset < max_gen_len, "init data longer than max gen length"
376
+
377
+ pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
378
+
379
+ unknown_token = -1
380
+
381
+ # Initialize the generated codes with the init data, padded with unknown tokens
382
+ gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
383
+ gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
384
+
385
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
386
+
387
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
388
+ assert start_offset_sequence is not None
389
+
390
+ # Generation
391
+ prev_offset = 0
392
+ gen_sequence_len = gen_sequence.shape[-1]
393
+
394
+ # Reset generation cache
395
+ if use_cache and self.lm.backbone.use_generation_cache:
396
+ self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
397
+
398
+ for offset in trange(start_offset_sequence, gen_sequence_len):
399
+
400
+ # Get the full sequence up to the current offset
401
+ curr_sequence = gen_sequence[..., prev_offset:offset]
402
+
403
+ next_token = self._sample_next_token(
404
+ curr_sequence,
405
+ conditioning_tensors=conditioning_tensors,
406
+ use_cache=use_cache,
407
+ cfg_scale=cfg_scale,
408
+ **kwargs
409
+ )
410
+
411
+ valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
412
+ next_token[~valid_mask] = self.lm.masked_token_id
413
+
414
+ # Update the generated sequence with the next token
415
+ gen_sequence[..., offset:offset+1] = torch.where(
416
+ gen_sequence[..., offset:offset+1] == unknown_token,
417
+ next_token,
418
+ gen_sequence[..., offset:offset+1]
419
+ )
420
+
421
+ if use_cache and self.lm.backbone.use_generation_cache:
422
+ # Only update the offset if caching is being used
423
+ prev_offset = offset
424
+
425
+ self.lm.backbone.update_generation_cache(offset)
426
+
427
+ if callback is not None:
428
+ # Callback to report progress
429
+ # Pass in the offset relative to the start of the sequence, and the length of the current sequence
430
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
431
+
432
+ assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
433
+
434
+ out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
435
+
436
+ # sanity checks over the returned codes and corresponding masks
437
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
438
+ assert (out_mask[..., :max_gen_len] == 1).all()
439
+
440
+ #out_codes = out_codes[..., 0:max_gen_len]
441
+
442
+ return out_codes
443
+
444
+
445
+ def generate_audio(
446
+ self,
447
+ **kwargs
448
+ ):
449
+ """
450
+ Generate audio from a batch of codes
451
+ """
452
+
453
+ codes = self.generate(**kwargs)
454
+
455
+ audio = self.pretransform.decode_tokens(codes)
456
+
457
+ return audio
458
+
459
+
460
+ def create_audio_lm_from_config(config):
461
+ model_config = config.get('model', None)
462
+ assert model_config is not None, 'model config must be specified in config'
463
+
464
+ sample_rate = config.get('sample_rate', None)
465
+ assert sample_rate is not None, "Must specify sample_rate in config"
466
+
467
+ lm_config = model_config.get('lm', None)
468
+ assert lm_config is not None, 'lm config must be specified in model config'
469
+
470
+ codebook_pattern = lm_config.get("codebook_pattern", "delay")
471
+
472
+ pattern_providers = {
473
+ 'parallel': ParallelPatternProvider,
474
+ 'delay': DelayedPatternProvider,
475
+ 'unroll': UnrolledPatternProvider,
476
+ 'valle': VALLEPattern,
477
+ 'musiclm': MusicLMPattern,
478
+ }
479
+
480
+ pretransform_config = model_config.get("pretransform", None)
481
+
482
+ pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
483
+
484
+ assert pretransform.is_discrete, "Pretransform must be discrete"
485
+
486
+ min_input_length = pretransform.downsampling_ratio
487
+
488
+ pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
489
+
490
+ conditioning_config = model_config.get('conditioning', None)
491
+
492
+ conditioner = None
493
+ if conditioning_config is not None:
494
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
495
+
496
+ cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
497
+ prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
498
+ global_cond_ids = lm_config.get('global_cond_ids', [])
499
+
500
+ lm_type = lm_config.get("type", None)
501
+ lm_model_config = lm_config.get("config", None)
502
+
503
+ assert lm_type is not None, "Must specify lm type in lm config"
504
+ assert lm_model_config is not None, "Must specify lm model config in lm config"
505
+
506
+ if lm_type == "x-transformers":
507
+ backbone = XTransformersAudioLMBackbone(**lm_model_config)
508
+ elif lm_type == "continuous_transformer":
509
+ backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
510
+ else:
511
+ raise NotImplementedError(f"Unrecognized lm type {lm_type}")
512
+
513
+ lm = AudioLanguageModel(
514
+ pattern_provider=pattern_provider,
515
+ backbone=backbone,
516
+ num_quantizers=pretransform.num_quantizers,
517
+ codebook_size=pretransform.codebook_size
518
+ )
519
+
520
+ model = AudioLanguageModelWrapper(
521
+ pretransform=pretransform,
522
+ lm=lm,
523
+ conditioner=conditioner,
524
+ sample_rate=sample_rate,
525
+ min_input_length=min_input_length,
526
+ cross_attn_cond_ids=cross_attn_cond_ids,
527
+ prepend_cond_ids=prepend_cond_ids,
528
+ global_cond_ids=global_cond_ids
529
+ )
530
+
531
+ return model
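For orientation, a minimal usage sketch of the factory above. The config path, the "prompt" conditioning key, and the sampling values are placeholder assumptions; the config must supply "sample_rate" plus a "model" section with a discrete "pretransform" and an "lm" block, as read by create_audio_lm_from_config.

import json
import torch
from stable_audio_tools.models.lm import create_audio_lm_from_config

# Hypothetical config file providing "sample_rate" and "model" -> "pretransform"/"lm"/"conditioning".
with open("lm_model_config.json") as f:
    config = json.load(f)

model = create_audio_lm_from_config(config)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Conditioning keys depend on the "conditioning" section of the config; "prompt" is a placeholder.
audio = model.generate_audio(
    max_gen_len=256,
    conditioning=[{"prompt": "warm analog synth chords"}],
    cfg_scale=3.0,
    temp=1.0,
    top_p=0.95,
)
print(audio.shape)  # [batch, channels, samples]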
stable_audio_tools/models/lm_backbone.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ from torch import nn
3
+ from x_transformers import ContinuousTransformerWrapper, Decoder
4
+
5
+ from .transformer import ContinuousTransformer
6
+
7
+ # Interface for backbone of a language model
8
+ # Handles conditioning and cross-attention
9
+ # Does not have to deal with patterns or quantizer heads
10
+ class AudioLMBackbone(nn.Module):
11
+ def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
12
+ super().__init__()
13
+
14
+ self.embed_dim = embed_dim
15
+ self.use_generation_cache = use_generation_cache
16
+
17
+ def forward(
18
+ self,
19
+ x,
20
+ cross_attn_cond=None,
21
+ prepend_cond=None,
22
+ prepend_cond_mask=None,
23
+ global_cond=None,
24
+ use_cache=False,
25
+ **kwargs
26
+ ):
27
+ raise NotImplementedError
28
+
29
+ def reset_generation_cache(
30
+ self,
31
+ max_seq_len,
32
+ batch_size,
33
+ dtype=None
34
+ ):
35
+ pass
36
+
37
+ def update_generation_cache(
38
+ self,
39
+ seqlen_offset
40
+ ):
41
+ pass
42
+
43
+ class XTransformersAudioLMBackbone(AudioLMBackbone):
44
+ def __init__(self,
45
+ embed_dim: int,
46
+ cross_attn_cond_dim: int = 0,
47
+ prepend_cond_dim: int = 0,
48
+ **kwargs):
49
+ super().__init__(embed_dim=embed_dim)
50
+
51
+ # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
52
+ self.model = ContinuousTransformerWrapper(
53
+ dim_in=embed_dim,
54
+ dim_out=embed_dim,
55
+ max_seq_len=0, # Not relevant without absolute positional embeds
56
+ attn_layers=Decoder(
57
+ dim=embed_dim,
58
+ attn_flash = True,
59
+ cross_attend = cross_attn_cond_dim > 0,
60
+ zero_init_branch_output=True,
61
+ use_abs_pos_emb = False,
62
+ rotary_pos_emb=True,
63
+ ff_swish = True,
64
+ ff_glu = True,
65
+ **kwargs
66
+ )
67
+ )
68
+
69
+ if prepend_cond_dim > 0:
70
+ # Prepend conditioning
71
+ self.to_prepend_embed = nn.Sequential(
72
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
73
+ nn.SiLU(),
74
+ nn.Linear(embed_dim, embed_dim, bias=False)
75
+ )
76
+
77
+ if cross_attn_cond_dim > 0:
78
+ # Cross-attention conditioning
79
+ self.to_cross_attn_embed = nn.Sequential(
80
+ nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
81
+ nn.SiLU(),
82
+ nn.Linear(embed_dim, embed_dim, bias=False)
83
+ )
84
+
85
+ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
86
+
87
+ prepend_length = 0
88
+ if prepend_cond is not None:
89
+ # Project the prepend conditioning to the embedding dimension
90
+ prepend_cond = self.to_prepend_embed(prepend_cond)
91
+ prepend_length = prepend_cond.shape[1]
92
+
93
+ if prepend_cond_mask is not None:
94
+ # Cast mask to bool
95
+ prepend_cond_mask = prepend_cond_mask.bool()
96
+
97
+ if cross_attn_cond is not None:
98
+ # Project the cross-attention conditioning to the embedding dimension
99
+ cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
100
+
101
+ return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
102
+
103
+ class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
104
+ def __init__(self,
105
+ embed_dim: int,
106
+ cross_attn_cond_dim: int = 0,
107
+ prepend_cond_dim: int = 0,
108
+ project_cross_attn_cond: bool = False,
109
+ **kwargs):
110
+ super().__init__(embed_dim=embed_dim)
111
+
112
+ # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
113
+ self.model = ContinuousTransformer(
114
+ dim=embed_dim,
115
+ dim_in=embed_dim,
116
+ dim_out=embed_dim,
117
+ cross_attend = cross_attn_cond_dim > 0,
118
+ cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
119
+ causal=True,
120
+ **kwargs
121
+ )
122
+
123
+ if prepend_cond_dim > 0:
124
+ # Prepend conditioning
125
+ self.to_prepend_embed = nn.Sequential(
126
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
127
+ nn.SiLU(),
128
+ nn.Linear(embed_dim, embed_dim, bias=False)
129
+ )
130
+
131
+ if cross_attn_cond_dim > 0 and project_cross_attn_cond:
132
+ # Cross-attention conditioning
133
+ self.to_cross_attn_embed = nn.Sequential(
134
+ nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
135
+ nn.SiLU(),
136
+ nn.Linear(embed_dim, embed_dim, bias=False)
137
+ )
138
+ else:
139
+ self.to_cross_attn_embed = nn.Identity()
140
+
141
+ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
142
+
143
+ prepend_length = 0
144
+ if prepend_cond is not None:
145
+ # Project the prepend conditioning to the embedding dimension
146
+ prepend_cond = self.to_prepend_embed(prepend_cond)
147
+ prepend_length = prepend_cond.shape[1]
148
+
149
+ if prepend_cond_mask is not None:
150
+ # Cast mask to bool
151
+ prepend_cond_mask = prepend_cond_mask.bool()
152
+
153
+ if cross_attn_cond is not None:
154
+ # Project the cross-attention conditioning to the embedding dimension
155
+ cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
156
+
157
+ return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
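A minimal shape sketch for the x-transformers backbone above; the sizes are illustrative assumptions, and depth/heads are simply forwarded to x_transformers' Decoder through **kwargs.

import torch
from stable_audio_tools.models.lm_backbone import XTransformersAudioLMBackbone

backbone = XTransformersAudioLMBackbone(
    embed_dim=512,
    cross_attn_cond_dim=768,   # e.g. a text encoder hidden size
    prepend_cond_dim=512,
    depth=4,                   # forwarded to x_transformers.Decoder
    heads=8,                   # forwarded to x_transformers.Decoder
)

x = torch.randn(2, 100, 512)               # [batch, seq_len, embed_dim]
cross_attn_cond = torch.randn(2, 77, 768)  # cross-attention conditioning tokens
prepend_cond = torch.randn(2, 1, 512)      # a single prepended conditioning token
prepend_cond_mask = torch.ones(2, 1)

out = backbone(x, cross_attn_cond=cross_attn_cond,
               prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask)
print(out.shape)  # torch.Size([2, 100, 512]) -- prepended tokens are sliced off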
stable_audio_tools/models/local_attention.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from .blocks import AdaRMSNorm
7
+ from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
8
+
9
+ def checkpoint(function, *args, **kwargs):
10
+ kwargs.setdefault("use_reentrant", False)
11
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
12
+
13
+ # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
14
+ class ContinuousLocalTransformer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ *,
18
+ dim,
19
+ depth,
20
+ dim_in = None,
21
+ dim_out = None,
22
+ causal = False,
23
+ local_attn_window_size = 64,
24
+ heads = 8,
25
+ ff_mult = 2,
26
+ cond_dim = 0,
27
+ cross_attn_cond_dim = 0,
28
+ **kwargs
29
+ ):
30
+ super().__init__()
31
+
32
+ dim_head = dim//heads
33
+
34
+ self.layers = nn.ModuleList([])
35
+
36
+ self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
37
+
38
+ self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
39
+
40
+ self.local_attn_window_size = local_attn_window_size
41
+
42
+ self.cond_dim = cond_dim
43
+
44
+ self.cross_attn_cond_dim = cross_attn_cond_dim
45
+
46
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
47
+
48
+ for _ in range(depth):
49
+
50
+ self.layers.append(nn.ModuleList([
51
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
52
+ Attention(
53
+ dim=dim,
54
+ dim_heads=dim_head,
55
+ causal=causal,
56
+ zero_init_output=True,
57
+ natten_kernel_size=local_attn_window_size,
58
+ ),
59
+ Attention(
60
+ dim=dim,
61
+ dim_heads=dim_head,
62
+ dim_context = cross_attn_cond_dim,
63
+ zero_init_output=True
64
+ ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
65
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
66
+ FeedForward(dim = dim, mult = ff_mult, no_bias=True)
67
+ ]))
68
+
69
+ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
70
+
71
+ x = checkpoint(self.project_in, x)
72
+
73
+ if prepend_cond is not None:
74
+ x = torch.cat([prepend_cond, x], dim=1)
75
+
76
+ pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
77
+
78
+ for attn_norm, attn, xattn, ff_norm, ff in self.layers:
79
+
80
+ residual = x
81
+ if cond is not None:
82
+ x = checkpoint(attn_norm, x, cond)
83
+ else:
84
+ x = checkpoint(attn_norm, x)
85
+
86
+ x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
87
+
88
+ if cross_attn_cond is not None:
89
+ x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
90
+
91
+ residual = x
92
+
93
+ if cond is not None:
94
+ x = checkpoint(ff_norm, x, cond)
95
+ else:
96
+ x = checkpoint(ff_norm, x)
97
+
98
+ x = checkpoint(ff, x) + residual
99
+
100
+ return checkpoint(self.project_out, x)
101
+
102
+ class TransformerDownsampleBlock1D(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_channels,
106
+ embed_dim = 768,
107
+ depth = 3,
108
+ heads = 12,
109
+ downsample_ratio = 2,
110
+ local_attn_window_size = 64,
111
+ **kwargs
112
+ ):
113
+ super().__init__()
114
+
115
+ self.downsample_ratio = downsample_ratio
116
+
117
+ self.transformer = ContinuousLocalTransformer(
118
+ dim=embed_dim,
119
+ depth=depth,
120
+ heads=heads,
121
+ local_attn_window_size=local_attn_window_size,
122
+ **kwargs
123
+ )
124
+
125
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
126
+
127
+ self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
128
+
129
+
130
+ def forward(self, x):
131
+
132
+ x = checkpoint(self.project_in, x)
133
+
134
+ # Compute
135
+ x = self.transformer(x)
136
+
137
+ # Trade sequence length for channels
138
+ x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
139
+
140
+ # Project back to embed dim
141
+ x = checkpoint(self.project_down, x)
142
+
143
+ return x
144
+
145
+ class TransformerUpsampleBlock1D(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels,
149
+ embed_dim,
150
+ depth = 3,
151
+ heads = 12,
152
+ upsample_ratio = 2,
153
+ local_attn_window_size = 64,
154
+ **kwargs
155
+ ):
156
+ super().__init__()
157
+
158
+ self.upsample_ratio = upsample_ratio
159
+
160
+ self.transformer = ContinuousLocalTransformer(
161
+ dim=embed_dim,
162
+ depth=depth,
163
+ heads=heads,
164
+ local_attn_window_size = local_attn_window_size,
165
+ **kwargs
166
+ )
167
+
168
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
169
+
170
+ self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
171
+
172
+ def forward(self, x):
173
+
174
+ # Project to embed dim
175
+ x = checkpoint(self.project_in, x)
176
+
177
+ # Project to increase channel dim
178
+ x = checkpoint(self.project_up, x)
179
+
180
+ # Trade channels for sequence length
181
+ x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
182
+
183
+ # Compute
184
+ x = self.transformer(x)
185
+
186
+ return x
187
+
188
+
189
+ class TransformerEncoder1D(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels,
193
+ out_channels,
194
+ embed_dims = [96, 192, 384, 768],
195
+ heads = [12, 12, 12, 12],
196
+ depths = [3, 3, 3, 3],
197
+ ratios = [2, 2, 2, 2],
198
+ local_attn_window_size = 64,
199
+ **kwargs
200
+ ):
201
+ super().__init__()
202
+
203
+ layers = []
204
+
205
+ for layer in range(len(depths)):
206
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
207
+
208
+ layers.append(
209
+ TransformerDownsampleBlock1D(
210
+ in_channels = prev_dim,
211
+ embed_dim = embed_dims[layer],
212
+ heads = heads[layer],
213
+ depth = depths[layer],
214
+ downsample_ratio = ratios[layer],
215
+ local_attn_window_size = local_attn_window_size,
216
+ **kwargs
217
+ )
218
+ )
219
+
220
+ self.layers = nn.Sequential(*layers)
221
+
222
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
223
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
224
+
225
+ def forward(self, x):
226
+ x = rearrange(x, "b c n -> b n c")
227
+ x = checkpoint(self.project_in, x)
228
+ x = self.layers(x)
229
+ x = checkpoint(self.project_out, x)
230
+ x = rearrange(x, "b n c -> b c n")
231
+
232
+ return x
233
+
234
+
235
+ class TransformerDecoder1D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ out_channels,
240
+ embed_dims = [768, 384, 192, 96],
241
+ heads = [12, 12, 12, 12],
242
+ depths = [3, 3, 3, 3],
243
+ ratios = [2, 2, 2, 2],
244
+ local_attn_window_size = 64,
245
+ **kwargs
246
+ ):
247
+
248
+ super().__init__()
249
+
250
+ layers = []
251
+
252
+ for layer in range(len(depths)):
253
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
254
+
255
+ layers.append(
256
+ TransformerUpsampleBlock1D(
257
+ in_channels = prev_dim,
258
+ embed_dim = embed_dims[layer],
259
+ heads = heads[layer],
260
+ depth = depths[layer],
261
+ upsample_ratio = ratios[layer],
262
+ local_attn_window_size = local_attn_window_size,
263
+ **kwargs
264
+ )
265
+ )
266
+
267
+ self.layers = nn.Sequential(*layers)
268
+
269
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
270
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
271
+
272
+ def forward(self, x):
273
+ x = rearrange(x, "b c n -> b n c")
274
+ x = checkpoint(self.project_in, x)
275
+ x = self.layers(x)
276
+ x = checkpoint(self.project_out, x)
277
+ x = rearrange(x, "b n c -> b c n")
278
+ return x
stable_audio_tools/models/musicgen.py ADDED
@@ -0,0 +1,161 @@
1
+ import torch
2
+ import typing as tp
3
+ from audiocraft.models import MusicGen, CompressionModel, LMModel
4
+ import audiocraft.quantization as qt
5
+ from .autoencoders import AudioAutoencoder
6
+ from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck
7
+
8
+ from audiocraft.modules.codebooks_patterns import (
9
+ DelayedPatternProvider,
10
+ MusicLMPattern,
11
+ ParallelPatternProvider,
12
+ UnrolledPatternProvider,
13
+ VALLEPattern,
14
+ )
15
+
16
+ from audiocraft.modules.conditioners import (
17
+ ConditionFuser,
18
+ ConditioningProvider,
19
+ T5Conditioner,
20
+ )
21
+
22
+ def create_musicgen_from_config(config):
23
+ model_config = config.get('model', None)
24
+ assert model_config is not None, 'model config must be specified in config'
25
+
26
+ if model_config.get("pretrained", False):
27
+ model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu")
28
+
29
+ if model_config.get("reinit_lm", False):
30
+ model.lm._init_weights("gaussian", "current", True)
31
+
32
+ return model
33
+
34
+ # Create MusicGen model from scratch
35
+ compression_config = model_config.get('compression', None)
36
+ assert compression_config is not None, 'compression config must be specified in model config'
37
+
38
+ compression_type = compression_config.get('type', None)
39
+ assert compression_type is not None, 'type must be specified in compression config'
40
+
41
+ if compression_type == 'pretrained':
42
+ compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"])
43
+ elif compression_type == "dac_rvq_ae":
44
+ from .autoencoders import create_autoencoder_from_config
45
+ autoencoder = create_autoencoder_from_config({"model": compression_config["config"], "sample_rate": config["sample_rate"]})
46
+ autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"])
47
+ compression_model = DACRVQCompressionModel(autoencoder)
48
+
49
+ lm_config = model_config.get('lm', None)
50
+ assert lm_config is not None, 'lm config must be specified in model config'
51
+
52
+ codebook_pattern = lm_config.pop("codebook_pattern", "delay")
53
+
54
+ pattern_providers = {
55
+ 'parallel': ParallelPatternProvider,
56
+ 'delay': DelayedPatternProvider,
57
+ 'unroll': UnrolledPatternProvider,
58
+ 'valle': VALLEPattern,
59
+ 'musiclm': MusicLMPattern,
60
+ }
61
+
62
+ pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks)
63
+
64
+ conditioning_config = model_config.get("conditioning", {})
65
+
66
+ condition_output_dim = conditioning_config.get("output_dim", 768)
67
+
68
+ condition_provider = ConditioningProvider(
69
+ conditioners = {
70
+ "description": T5Conditioner(
71
+ name="t5-base",
72
+ output_dim=condition_output_dim,
73
+ word_dropout=0.3,
74
+ normalize_text=False,
75
+ finetune=False,
76
+ device="cpu"
77
+ )
78
+ }
79
+ )
80
+
81
+ condition_fuser = ConditionFuser(fuse2cond={
82
+ "cross": ["description"],
83
+ "prepend": [],
84
+ "sum": []
85
+ })
86
+
87
+ lm = LMModel(
88
+ pattern_provider = pattern_provider,
89
+ condition_provider = condition_provider,
90
+ fuser = condition_fuser,
91
+ n_q = compression_model.num_codebooks,
92
+ card = compression_model.cardinality,
93
+ **lm_config
94
+ )
95
+
96
+
97
+ model = MusicGen(
98
+ name = model_config.get("name", "musicgen-scratch"),
99
+ compression_model = compression_model,
100
+ lm = lm,
101
+ max_duration=30
102
+ )
103
+
104
+ return model
105
+
106
+ class DACRVQCompressionModel(CompressionModel):
107
+ def __init__(self, autoencoder: AudioAutoencoder):
108
+ super().__init__()
109
+ self.model = autoencoder.eval()
110
+
111
+ assert isinstance(self.model.bottleneck, DACRVQBottleneck) or isinstance(self.model.bottleneck, DACRVQVAEBottleneck), "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck"
112
+
113
+ self.n_quantizers = self.model.bottleneck.num_quantizers
114
+
115
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
116
+ raise NotImplementedError("Forward and training with DAC RVQ not supported")
117
+
118
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
119
+ _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers)
120
+ codes = info["codes"]
121
+ return codes, None
122
+
123
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
124
+ assert scale is None
125
+ z_q = self.decode_latent(codes)
126
+ return self.model.decode(z_q)
127
+
128
+ def decode_latent(self, codes: torch.Tensor):
129
+ """Decode from the discrete codes to continuous latent space."""
130
+ return self.model.bottleneck.quantizer.from_codes(codes)[0]
131
+
132
+ @property
133
+ def channels(self) -> int:
134
+ return self.model.io_channels
135
+
136
+ @property
137
+ def frame_rate(self) -> float:
138
+ return self.model.sample_rate / self.model.downsampling_ratio
139
+
140
+ @property
141
+ def sample_rate(self) -> int:
142
+ return self.model.sample_rate
143
+
144
+ @property
145
+ def cardinality(self) -> int:
146
+ return self.model.bottleneck.quantizer.codebook_size
147
+
148
+ @property
149
+ def num_codebooks(self) -> int:
150
+ return self.n_quantizers
151
+
152
+ @property
153
+ def total_codebooks(self) -> int:
154
+ return self.model.bottleneck.num_quantizers
155
+
156
+ def set_num_codebooks(self, n: int):
157
+ """Set the active number of codebooks used by the quantizer.
158
+ """
159
+ assert n >= 1
160
+ assert n <= self.total_codebooks
161
+ self.n_quantizers = n
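A minimal sketch of the pretrained branch of create_musicgen_from_config. The checkpoint name is an assumption (anything accepted by audiocraft's MusicGen.get_pretrained works), and running it requires audiocraft plus network access to fetch the weights.

from stable_audio_tools.models.musicgen import create_musicgen_from_config

config = {
    "sample_rate": 32000,        # only consulted by the from-scratch branch
    "model": {
        "pretrained": "facebook/musicgen-small",  # assumed checkpoint id
        # "reinit_lm": True,     # optionally re-initialize the LM weights
    },
}

model = create_musicgen_from_config(config)   # returns an audiocraft MusicGen instance
model.set_generation_params(duration=8)
wav = model.generate(["lo-fi hip hop beat"])  # [batch, channels, samples]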
stable_audio_tools/models/pqmf.py ADDED
@@ -0,0 +1,393 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from scipy.optimize import fmin
7
+ from scipy.signal import firwin, kaiserord
8
+
9
+ class PQMF(nn.Module):
10
+ """
11
+ Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
12
+ Uses polyphase representation which is computationally more efficient for real-time.
13
+
14
+ Parameters:
15
+ - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
16
+ - num_bands (int): Number of desired frequency bands. It must be a power of 2.
17
+ """
18
+
19
+ def __init__(self, attenuation, num_bands):
20
+ super(PQMF, self).__init__()
21
+
22
+ # Ensure num_bands is a power of 2
23
+ is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
24
+ assert is_power_of_2, "'num_bands' must be a power of 2."
25
+
26
+ # Create the prototype filter
27
+ prototype_filter = design_prototype_filter(attenuation, num_bands)
28
+ filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
29
+ padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
30
+
31
+ # Register filters and settings
32
+ self.register_buffer("filter_bank", padded_filter_bank)
33
+ self.register_buffer("prototype", prototype_filter)
34
+ self.num_bands = num_bands
35
+
36
+ def forward(self, signal):
37
+ """Decompose the signal into multiple frequency bands."""
38
+ # If signal is not a pytorch tensor of Batch x Channels x Length, convert it
39
+ signal = prepare_signal_dimensions(signal)
40
+ # The signal length must be a multiple of num_bands. Pad it with zeros.
41
+ signal = pad_signal(signal, self.num_bands)
42
+ # run it
43
+ signal = polyphase_analysis(signal, self.filter_bank)
44
+ return apply_alias_cancellation(signal)
45
+
46
+ def inverse(self, bands):
47
+ """Reconstruct the original signal from the frequency bands."""
48
+ bands = apply_alias_cancellation(bands)
49
+ return polyphase_synthesis(bands, self.filter_bank)
50
+
51
+
52
+ def prepare_signal_dimensions(signal):
53
+ """
54
+ Rearrange signal into Batch x Channels x Length.
55
+
56
+ Parameters
57
+ ----------
58
+ signal : torch.Tensor or numpy.ndarray
59
+ The input signal.
60
+
61
+ Returns
62
+ -------
63
+ torch.Tensor
64
+ Preprocessed signal tensor.
65
+ """
66
+ # Convert numpy to torch tensor
67
+ if isinstance(signal, np.ndarray):
68
+ signal = torch.from_numpy(signal)
69
+
70
+ # Ensure tensor
71
+ if not isinstance(signal, torch.Tensor):
72
+ raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
73
+
74
+ # Modify dimension of signal to Batch x Channels x Length
75
+ if signal.dim() == 1:
76
+ # This is just a mono signal. Unsqueeze to 1 x 1 x Length
77
+ signal = signal.unsqueeze(0).unsqueeze(0)
78
+ elif signal.dim() == 2:
79
+ # This is a multi-channel signal (e.g. stereo)
80
+ # Rearrange so that larger dimension (Length) is last
81
+ if signal.shape[0] > signal.shape[1]:
82
+ signal = signal.T
83
+ # Unsqueeze to 1 x Channels x Length
84
+ signal = signal.unsqueeze(0)
85
+ return signal
86
+
87
+ def pad_signal(signal, num_bands):
88
+ """
89
+ Pads the signal to make its length divisible by the given number of bands.
90
+
91
+ Parameters
92
+ ----------
93
+ signal : torch.Tensor
94
+ The input signal tensor, where the last dimension represents the signal length.
95
+
96
+ num_bands : int
97
+ The number of bands by which the signal length should be divisible.
98
+
99
+ Returns
100
+ -------
101
+ torch.Tensor
102
+ The padded signal tensor. If the original signal length was already divisible
103
+ by num_bands, returns the original signal unchanged.
104
+ """
105
+ remainder = signal.shape[-1] % num_bands
106
+ if remainder > 0:
107
+ padding_size = num_bands - remainder
108
+ signal = nn.functional.pad(signal, (0, padding_size))
109
+ return signal
110
+
111
+ def generate_modulated_filter_bank(prototype_filter, num_bands):
112
+ """
113
+ Generate a QMF bank of cosine modulated filters based on a given prototype filter.
114
+
115
+ Parameters
116
+ ----------
117
+ prototype_filter : torch.Tensor
118
+ The prototype filter used as the basis for modulation.
119
+ num_bands : int
120
+ The number of desired subbands or filters.
121
+
122
+ Returns
123
+ -------
124
+ torch.Tensor
125
+ A bank of cosine modulated filters.
126
+ """
127
+
128
+ # Initialize indices for modulation.
129
+ subband_indices = torch.arange(num_bands).reshape(-1, 1)
130
+
131
+ # Calculate the length of the prototype filter.
132
+ filter_length = prototype_filter.shape[-1]
133
+
134
+ # Generate symmetric time indices centered around zero.
135
+ time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
136
+
137
+ # Calculate phase offsets to ensure orthogonality between subbands.
138
+ phase_offsets = (-1)**subband_indices * np.pi / 4
139
+
140
+ # Compute the cosine modulation function.
141
+ modulation = torch.cos(
142
+ (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
143
+ )
144
+
145
+ # Apply modulation to the prototype filter.
146
+ modulated_filters = 2 * prototype_filter * modulation
147
+
148
+ return modulated_filters
149
+
150
+
151
+ def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
152
+ """
153
+ Design a lowpass filter using the Kaiser window.
154
+
155
+ Parameters
156
+ ----------
157
+ angular_cutoff : float
158
+ The angular frequency cutoff of the filter.
159
+ attenuation : float
160
+ The desired stopband attenuation in decibels (dB).
161
+ filter_length : int, optional
162
+ Desired length of the filter. If not provided, it's computed based on the given specs.
163
+
164
+ Returns
165
+ -------
166
+ ndarray
167
+ The designed lowpass filter coefficients.
168
+ """
169
+
170
+ estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
171
+
172
+ # Ensure the estimated length is odd.
173
+ estimated_length = 2 * (estimated_length // 2) + 1
174
+
175
+ if filter_length is None:
176
+ filter_length = estimated_length
177
+
178
+ return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
179
+
180
+
181
+ def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
182
+ """
183
+ Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
184
+
185
+ Parameters
186
+ ----------
187
+ angular_cutoff : float
188
+ Angular frequency cutoff of the filter.
189
+ attenuation : float
190
+ Desired stopband attenuation in dB.
191
+ num_bands : int
192
+ Number of bands for the multiband filter system.
193
+ filter_length : int, optional
194
+ Desired length of the filter.
195
+
196
+ Returns
197
+ -------
198
+ float
199
+ The computed objective (loss) value for the given filter specs.
200
+ """
201
+
202
+ filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
203
+ convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
204
+
205
+ return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
206
+
207
+
208
+ def design_prototype_filter(attenuation, num_bands, filter_length=None):
209
+ """
210
+ Design the optimal prototype filter for a multiband system given the desired specs.
211
+
212
+ Parameters
213
+ ----------
214
+ attenuation : float
215
+ The desired stopband attenuation in dB.
216
+ num_bands : int
217
+ Number of bands for the multiband filter system.
218
+ filter_length : int, optional
219
+ Desired length of the filter. If not provided, it's computed based on the given specs.
220
+
221
+ Returns
222
+ -------
223
+ ndarray
224
+ The optimal prototype filter coefficients.
225
+ """
226
+
227
+ optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
228
+ 1 / num_bands, disp=0)[0]
229
+
230
+ prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
231
+ return torch.tensor(prototype_filter, dtype=torch.float32)
232
+
233
+ def pad_to_nearest_power_of_two(x):
234
+ """
235
+ Pads the input tensor 'x' on both sides such that its last dimension
236
+ becomes the nearest larger power of two.
237
+
238
+ Parameters:
239
+ -----------
240
+ x : torch.Tensor
241
+ The input tensor to be padded.
242
+
243
+ Returns:
244
+ --------
245
+ torch.Tensor
246
+ The padded tensor.
247
+ """
248
+ current_length = x.shape[-1]
249
+ target_length = 2**math.ceil(math.log2(current_length))
250
+
251
+ total_padding = target_length - current_length
252
+ left_padding = total_padding // 2
253
+ right_padding = total_padding - left_padding
254
+
255
+ return nn.functional.pad(x, (left_padding, right_padding))
256
+
257
+ def apply_alias_cancellation(x):
258
+ """
259
+ Applies alias cancellation by inverting the sign of every
260
+ second element of every second row, starting from the second
261
+ row's first element in a tensor.
262
+
263
+ This operation helps ensure that the aliasing introduced in
264
+ each band during the decomposition will be counteracted during
265
+ the reconstruction.
266
+
267
+ Parameters:
268
+ -----------
269
+ x : torch.Tensor
270
+ The input tensor.
271
+
272
+ Returns:
273
+ --------
274
+ torch.Tensor
275
+ Tensor with specific elements' sign inverted for alias cancellation.
276
+ """
277
+
278
+ # Create a mask of the same shape as 'x', initialized with all ones
279
+ mask = torch.ones_like(x)
280
+
281
+ # Update specific elements in the mask to -1 to perform inversion
282
+ mask[..., 1::2, ::2] = -1
283
+
284
+ # Apply the mask to the input tensor 'x'
285
+ return x * mask
286
+
287
+ def ensure_odd_length(tensor):
288
+ """
289
+ Pads the last dimension of a tensor to ensure its size is odd.
290
+
291
+ Parameters:
292
+ -----------
293
+ tensor : torch.Tensor
294
+ Input tensor whose last dimension might need padding.
295
+
296
+ Returns:
297
+ --------
298
+ torch.Tensor
299
+ The original tensor if its last dimension was already odd,
300
+ or the padded tensor with an odd-sized last dimension.
301
+ """
302
+
303
+ last_dim_size = tensor.shape[-1]
304
+
305
+ if last_dim_size % 2 == 0:
306
+ tensor = nn.functional.pad(tensor, (0, 1))
307
+
308
+ return tensor
309
+
310
+ def polyphase_analysis(signal, filter_bank):
311
+ """
312
+ Applies the polyphase method to efficiently analyze the signal using a filter bank.
313
+
314
+ Parameters:
315
+ -----------
316
+ signal : torch.Tensor
317
+ Input signal tensor with shape (Batch x Channels x Length).
318
+
319
+ filter_bank : torch.Tensor
320
+ Filter bank tensor with shape (Bands x Length).
321
+
322
+ Returns:
323
+ --------
324
+ torch.Tensor
325
+ Signal split into sub-bands. (Batch x Channels x Bands x Length)
326
+ """
327
+
328
+ num_bands = filter_bank.shape[0]
329
+ num_channels = signal.shape[1]
330
+
331
+ # Rearrange signal for polyphase processing.
332
+ # Also combine Batch x Channel into one dimension for now.
333
+ #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
334
+ signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
335
+
336
+ # Rearrange the filter bank for matching signal shape
337
+ filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
338
+
339
+ # Apply convolution with appropriate padding to maintain spatial dimensions
340
+ padding = filter_bank.shape[-1] // 2
341
+ filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
342
+
343
+ # Truncate the last dimension post-convolution to adjust the output shape
344
+ filtered_signal = filtered_signal[..., :-1]
345
+ # Rearrange the first dimension back into Batch x Channels
346
+ filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
347
+
348
+ return filtered_signal
349
+
350
+ def polyphase_synthesis(signal, filter_bank):
351
+ """
352
+ Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
353
+
354
+ Parameters
355
+ ----------
356
+ signal : torch.Tensor
357
+ Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
358
+
359
+ filter_bank : torch.Tensor
360
+ Analysis filter bank (shape: Bands x Length).
361
+
362
365
+ Returns
366
+ -------
367
+ torch.Tensor
368
+ Reconstructed signal (shape: Batch x Channels X Length)
369
+ """
370
+
371
+ num_bands = filter_bank.shape[0]
372
+ num_channels = signal.shape[1]
373
+
374
+ # Rearrange the filter bank
375
+ filter_bank = filter_bank.flip(-1)
376
+ filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
377
+
378
+ # Combine Batch x Channels into one dimension for now.
379
+ signal = rearrange(signal, "b c n t -> (b c) n t")
380
+
381
+ # Apply convolution with appropriate padding
382
+ padding_amount = filter_bank.shape[-1] // 2 + 1
383
+ reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
384
+
385
+ # Scale the result
386
+ reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
387
+
388
+ # Reorganize the output and truncate
389
+ reconstructed_signal = reconstructed_signal.flip(1)
390
+ reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
391
+ reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
392
+
393
+ return reconstructed_signal
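A minimal round-trip sketch for the PQMF module above; the band count and signal length are illustrative, and the reconstructed length can differ slightly from the input because of the filter-delay trimming in polyphase_synthesis.

import torch
from stable_audio_tools.models.pqmf import PQMF

pqmf = PQMF(attenuation=100, num_bands=4)

signal = torch.randn(1, 2, 4096)   # [batch, channels, length], length divisible by num_bands
bands = pqmf(signal)               # [batch, channels, num_bands, length // num_bands]
print(bands.shape)

recon = pqmf.inverse(bands)        # [batch, channels, ~length]
print(recon.shape)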