"fred-dev" committed
Commit 5915064 · 1 Parent(s): 9eceef4

added repo files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. LICENSES/LICENSE_ADP.txt +21 -0
  3. LICENSES/LICENSE_DESCRIPT.txt +21 -0
  4. LICENSES/LICENSE_META.txt +21 -0
  5. LICENSES/LICENSE_NVIDIA.txt +21 -0
  6. LICENSES/LICENSE_XTRANSFORMERS.txt +21 -0
  7. app.py +16 -3
  8. defaults.ini +56 -0
  9. docs/autoencoders.md +354 -0
  10. docs/conditioning.md +151 -0
  11. docs/datasets.md +75 -0
  12. docs/pretransforms.md +43 -0
  13. model_config_float_conditioning_dit_all.json +208 -0
  14. pyproject.toml +3 -0
  15. run_gradio.py +31 -0
  16. run_tests.py +44 -0
  17. scripts/ds_zero_to_pl_ckpt.py +14 -0
  18. setup.py +46 -0
  19. stable_audio_tools/__init__.py +2 -0
  20. stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py +4 -0
  21. stable_audio_tools/configs/dataset_configs/local_training_example.json +11 -0
  22. stable_audio_tools/configs/dataset_configs/s3_wds_example.json +10 -0
  23. stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json +71 -0
  24. stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json +88 -0
  25. stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json +111 -0
  26. stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json +122 -0
  27. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json +18 -0
  28. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json +18 -0
  29. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json +18 -0
  30. stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json +18 -0
  31. stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json +22 -0
  32. stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json +107 -0
  33. stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json +127 -0
  34. stable_audio_tools/data/__init__.py +0 -0
  35. stable_audio_tools/data/dataset.py +597 -0
  36. stable_audio_tools/data/utils.py +96 -0
  37. stable_audio_tools/inference/__init__.py +0 -0
  38. stable_audio_tools/inference/generation.py +243 -0
  39. stable_audio_tools/inference/sampling.py +170 -0
  40. stable_audio_tools/inference/utils.py +35 -0
  41. stable_audio_tools/interface/__init__.py +0 -0
  42. stable_audio_tools/interface/gradio.py +782 -0
  43. stable_audio_tools/interface/testing.py +409 -0
  44. stable_audio_tools/models/__init__.py +1 -0
  45. stable_audio_tools/models/adp.py +1588 -0
  46. stable_audio_tools/models/autoencoders.py +800 -0
  47. stable_audio_tools/models/blocks.py +339 -0
  48. stable_audio_tools/models/bottleneck.py +326 -0
  49. stable_audio_tools/models/conditioners.py +558 -0
  50. stable_audio_tools/models/diffusion.py +678 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Stability AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSES/LICENSE_ADP.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 archinet.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSES/LICENSE_DESCRIPT.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023-present, Descript

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSES/LICENSE_META.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSES/LICENSE_NVIDIA.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 NVIDIA CORPORATION.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSES/LICENSE_XTRANSFORMERS.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py CHANGED
@@ -1,7 +1,20 @@
 import gradio as gr
+from stable_audio_tools import get_pretrained_model
+from stable_audio_tools.interface.gradio import create_ui
+import json
+
+import torch
 
 def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+    torch.manual_seed(42)
+
+    interface = create_ui(
+        model_config_path = "model_config_float_conditioning_dit_all.json",
+        ckpt_path="epoch=684-step=319200.ckpt",
+        pretrained_name="",
+        pretransform_ckpt_path="",
+        model_half=False
+    )
+    interface.queue()
+    interface.launch(share=True)
defaults.ini ADDED
@@ -0,0 +1,56 @@

[DEFAULTS]

# name of the run
name = stable_audio_tools

# the batch size
batch_size = 8

# number of GPUs to use for training
num_gpus = 1

# number of nodes to use for training
num_nodes = 1

# Multi-GPU strategy for PyTorch Lightning
strategy = ""

# Precision to use for training
precision = "16-mixed"

# number of CPU workers for the DataLoader
num_workers = 8

# the random seed
seed = 42

# Batches for gradient accumulation
accum_batches = 1

# Number of steps between checkpoints
checkpoint_every = 10000

# trainer checkpoint file to restart training from
ckpt_path = ''

# model checkpoint file to start a new training run from
pretrained_ckpt_path = ''

# Checkpoint path for the pretransform model if needed
pretransform_ckpt_path = ''

# configuration file specifying model hyperparameters
model_config = ''

# configuration for datasets
dataset_config = ''

# directory to save the checkpoints in
save_dir = ''

# gradient_clip_val passed into PyTorch Lightning Trainer
gradient_clip_val = 0.0

# remove the weight norm from the pretransform model
remove_pretransform_weight_norm = ''
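These defaults are merged with command-line overrides by the training entry point. As a rough, standalone illustration of the file's shape (not the project's own config-loading code), the `[DEFAULTS]` section can be read with the standard library:

```python
# Minimal sketch: inspect defaults.ini with the standard library.
# Illustrative only; the repo's training scripts do their own argument/INI merging.
import configparser

parser = configparser.ConfigParser()
parser.read("defaults.ini")

defaults = dict(parser["DEFAULTS"])
print(defaults["batch_size"])                          # "8" -- configparser returns strings
print(parser.getint("DEFAULTS", "checkpoint_every"))   # 10000 as an int
```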
docs/autoencoders.md ADDED
@@ -0,0 +1,354 @@
# Autoencoders
At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*.

The *encoder* takes in a sequence (such as mono or stereo audio) and outputs a compressed representation of that sequence as a d-channel "latent sequence", usually heavily downsampled by a constant factor.

The *decoder* takes in a d-channel latent sequence and upsamples it back to the original input sequence length, reversing the compression of the encoder.

Autoencoders are trained with a combination of reconstruction and adversarial losses in order to create a compact and invertible representation of raw audio data that allows downstream models to work in a data-compressed "latent space", with various desirable and controllable properties such as reduced sequence length, noise resistance, and discretization.

The autoencoder architectures defined in `stable-audio-tools` are largely fully-convolutional, which allows autoencoders trained on small lengths to be applied to arbitrary-length sequences. For example, an autoencoder trained on 1-second samples could be used to encode 45-second inputs to a latent diffusion model.

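To make the shapes concrete, here is a small sketch of the bookkeeping implied by `latent_dim` and `downsampling_ratio` (the numbers are taken from the `dac_2048_32_vae` config later in this commit; the helper below is illustrative, not a class from the library):

```python
import math

# Values from stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json
sample_rate = 44100
downsampling_ratio = 2048   # product of the encoder strides: 4 * 8 * 8 * 8
latent_dim = 32             # channel count of the latent sequence

def latent_shape(batch, seconds):
    """Rough shape of the encoder output for a [batch, audio_channels, samples] input."""
    samples = int(seconds * sample_rate)
    latent_frames = math.ceil(samples / downsampling_ratio)
    return (batch, latent_dim, latent_frames)

# A 1-second training crop and a 45-second inference clip map to the same latent
# channel count, just different latent sequence lengths:
print(latent_shape(1, 1))    # (1, 32, 22)
print(latent_shape(1, 45))   # (1, 32, 969)
```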
# Model configs
The model config file for an autoencoder should set the `model_type` to `autoencoder`, and the `model` object should have the following properties:

- `encoder`
  - Configuration for the autoencoder's encoder half
- `decoder`
  - Configuration for the autoencoder's decoder half
- `latent_dim`
  - Latent dimension of the autoencoder, used by inference scripts and downstream models
- `downsampling_ratio`
  - Downsampling ratio between the input sequence and the latent sequence, used by inference scripts and downstream models
- `io_channels`
  - Number of input and output channels for the autoencoder when they're the same, used by inference scripts and downstream models
- `bottleneck`
  - Configuration for the autoencoder's bottleneck
  - Optional
- `pretransform`
  - A pretransform definition for the autoencoder, such as wavelet decomposition or another autoencoder
  - See [pretransforms.md](pretransforms.md) for more information
  - Optional
- `in_channels`
  - Specifies the number of input channels for the autoencoder, when it's different from `io_channels`, such as in a mono-to-stereo model
  - Optional
- `out_channels`
  - Specifies the number of output channels for the autoencoder, when it's different from `io_channels`
  - Optional

# Training configs
The `training` config in the autoencoder model config file should have the following properties:
- `learning_rate`
  - The learning rate to use during training; a fixed learning rate is currently the only option
- `use_ema`
  - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
  - Optional. Default: `false`
- `warmup_steps`
  - The number of training steps before turning on adversarial losses
  - Optional. Default: `0`
- `encoder_freeze_on_warmup`
  - If true, freezes the encoder after the warmup steps have completed, so adversarial training only affects the decoder.
  - Optional. Default: `false`
- `loss_configs`
  - Configurations for the loss function calculation
  - Optional

## Loss configs
There are a few different types of losses that are used for autoencoder training, including spectral losses, time-domain losses, adversarial losses, and bottleneck-specific losses.

Hyperparameters for these losses, as well as loss weighting factors, can be configured in the `loss_configs` property in the `training` config.

### Spectral losses
Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.

For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module.

For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module.

#### Example config
```json
"spectral": {
    "type": "mrstft",
    "config": {
        "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
        "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
        "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
        "perceptual_weighting": true
    },
    "weights": {
        "mrstft": 1.0
    }
}
```

### Time-domain loss
We compute the L1 distance between the original audio and the decoded audio to provide a time-domain loss.

#### Example config
```json
"time": {
    "type": "l1",
    "weights": {
        "l1": 0.1
    }
}
```

### Adversarial losses
Adversarial losses bring in an ensemble of discriminator models to discriminate between real and fake audio, providing a signal to the autoencoder on perceptual discrepancies to fix.

We largely rely on the [multi-scale STFT discriminator](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/msstftd.py#L99) from the EnCodec repo.

#### Example config
```json
"discriminator": {
    "type": "encodec",
    "config": {
        "filters": 32,
        "n_ffts": [2048, 1024, 512, 256, 128],
        "hop_lengths": [512, 256, 128, 64, 32],
        "win_lengths": [2048, 1024, 512, 256, 128]
    },
    "weights": {
        "adversarial": 0.1,
        "feature_matching": 5.0
    }
}
```

## Demo config
The only property to set for autoencoder training demos is the `demo_every` property, determining the number of steps between demos.

### Example config
```json
"demo": {
    "demo_every": 2000
}
```

# Encoder and decoder types
Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably.

## Oobleck
Oobleck is Harmonai's in-house autoencoder architecture, implementing features from a variety of other autoencoder architectures.

### Example config
```json
"encoder": {
    "type": "oobleck",
    "config": {
        "in_channels": 2,
        "channels": 128,
        "c_mults": [1, 2, 4, 8],
        "strides": [2, 4, 8, 8],
        "latent_dim": 128,
        "use_snake": true
    }
},
"decoder": {
    "type": "oobleck",
    "config": {
        "out_channels": 2,
        "channels": 128,
        "c_mults": [1, 2, 4, 8],
        "strides": [2, 4, 8, 8],
        "latent_dim": 64,
        "use_snake": true,
        "use_nearest_upsample": false
    }
}
```

## DAC
This is the Encoder and Decoder definitions from the `descript-audio-codec` repo. It's a simple fully-convolutional autoencoder with channels doubling every level. The encoder and decoder configs are passed directly into the constructors for the DAC [Encoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L64) and [Decoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L115).

**Note: This does not include the DAC quantizer, and does not load pre-trained DAC models; this is just the encoder and decoder definitions.**

### Example config
```json
"encoder": {
    "type": "dac",
    "config": {
        "in_channels": 2,
        "latent_dim": 32,
        "d_model": 128,
        "strides": [2, 4, 4, 4]
    }
},
"decoder": {
    "type": "dac",
    "config": {
        "out_channels": 2,
        "latent_dim": 32,
        "channels": 1536,
        "rates": [4, 4, 4, 2]
    }
}
```

## SEANet
This is the SEANetEncoder and SEANetDecoder definitions from Meta's EnCodec repo. This is the same encoder and decoder architecture used in the EnCodec models used in MusicGen, without the quantizer.

The encoder and decoder configs are passed directly into the [SEANetEncoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L66C12-L66C12) and [SEANetDecoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L147) classes, though we reverse the input order of the strides (ratios) in the encoder to make it consistent with the order in the decoder.

### Example config
```json
"encoder": {
    "type": "seanet",
    "config": {
        "channels": 2,
        "dimension": 128,
        "n_filters": 64,
        "ratios": [4, 4, 8, 8],
        "n_residual_layers": 1,
        "dilation_base": 2,
        "lstm": 2,
        "norm": "weight_norm"
    }
},
"decoder": {
    "type": "seanet",
    "config": {
        "channels": 2,
        "dimension": 64,
        "n_filters": 64,
        "ratios": [4, 4, 8, 8],
        "n_residual_layers": 1,
        "dilation_base": 2,
        "lstm": 2,
        "norm": "weight_norm"
    }
},
```

# Bottlenecks
In our terminology, the "bottleneck" of an autoencoder is a module placed between the encoder and decoder to enforce particular constraints on the latent space the encoder creates.

Bottlenecks have a similar interface to the autoencoder, with `encode()` and `decode()` functions defined. Some bottlenecks return extra information in addition to the output latent series, such as quantized token indices, or additional losses to be considered during training.

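As a rough illustration of that interface (names and signatures here are illustrative, not the library's actual classes), a bottleneck is just a module whose `encode` can return extra info alongside the latents:

```python
import torch
import torch.nn as nn

class TanhBottleneckSketch(nn.Module):
    """Illustrative stand-in: encode() may return extra info, decode() undoes any needed transform."""

    def encode(self, latents: torch.Tensor):
        # Soft-clip the latents to [-1, 1]; no extra losses for this bottleneck,
        # so the info dict stays empty. A VAE-style bottleneck would return e.g. {"kl": ...}.
        return torch.tanh(latents), {}

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        # Nothing to undo for tanh; a quantizing bottleneck would dequantize here.
        return latents

bottleneck = TanhBottleneckSketch()
z, info = bottleneck.encode(torch.randn(1, 64, 256))   # [batch, latent_dim, latent_seq]
reconstructed_latents = bottleneck.decode(z)
```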
To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with the `type` property set to one of the bottleneck types below and, where needed, a `config` object for that bottleneck.

## VAE

The Variational Autoencoder (VAE) bottleneck splits the encoder's output in half along the channel dimension, treats the two halves as the "mean" and "scale" parameters for VAE sampling, and performs the latent sampling. At a basic level, the "scale" values determine the amount of noise to add to the "mean" latents, which creates a noise-resistant latent space where more of the latent space decodes to perceptually "valid" audio. This is particularly helpful for diffusion models, where the outputs of the diffusion sampling process leave a bit of Gaussian error noise.

**Note: For the VAE bottleneck to work, the output dimension of the encoder must be twice the size of the input dimension for the decoder.**

### Example config
```json
"bottleneck": {
    "type": "vae"
}
```

### Extra info
The VAE bottleneck also returns a `kl` value in the encoder info. This is the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the encoded/sampled latent space and a Gaussian distribution. By including this value as a loss value to optimize, we push our latent distribution closer to a normal distribution, potentially trading off reconstruction quality.

### Example loss config
```json
"bottleneck": {
    "type": "kl",
    "weights": {
        "kl": 1e-4
    }
}
```

## Tanh
This bottleneck applies the tanh function to the latent series, "soft-clipping" the latent values to be between -1 and 1. This is a quick and dirty way to enforce a limit on the variance of the latent space, but training these models can be unstable, as it's seemingly easy for the latent space to saturate the values to -1 or 1 and never recover.

### Example config
```json
"bottleneck": {
    "type": "tanh"
}
```

## Wasserstein
The Wasserstein bottleneck implements the WAE-MMD regularization method from the [Wasserstein Auto-Encoders](https://arxiv.org/abs/1711.01558) paper, calculating the Maximum Mean Discrepancy (MMD) between the latent space and a Gaussian distribution. Including this value as a loss value to optimize leads to a more Gaussian latent space, but does not require stochastic sampling as with a VAE, so the encoder is deterministic.

The Wasserstein bottleneck also exposes the `noise_augment_dim` property, which concatenates `noise_augment_dim` channels of Gaussian noise to the latent series before passing into the decoder. This adds some stochasticity to the latents, which can be helpful for adversarial training, while keeping the encoder outputs deterministic.

**Note: The MMD calculation is very VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**

### Example config
```json
"bottleneck": {
    "type": "wasserstein"
}
```

### Extra info
This bottleneck adds the `mmd` value to the encoder info, representing the Maximum Mean Discrepancy.

### Example loss config
```json
"bottleneck": {
    "type": "mmd",
    "weights": {
        "mmd": 100
    }
}
```

## L2 normalization (Spherical autoencoder)
The L2 normalization bottleneck normalizes the latents across the channel dimension, projecting the latents to a d-dimensional hypersphere. This acts as a form of latent space normalization.

### Example config
```json
"bottleneck": {
    "type": "l2_norm"
}
```

## RVQ
Residual vector quantization (RVQ) is currently the leading method for learning discrete neural audio codecs (tokenizers for audio). In vector quantization, each item in the latent sequence is individually "snapped" to the nearest vector in a discrete "codebook" of learned vectors. The index of the vector in the codebook can then be used as a token index for things like autoregressive transformers. Residual vector quantization improves the precision of normal vector quantization by adding additional codebooks. For a deeper dive into RVQ, check out [this blog post by Dr. Scott Hawley](https://drscotthawley.github.io/blog/posts/2023-06-12-RVQ.html).

This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidrains/vector-quantize-pytorch/tree/master) from the `vector-quantize-pytorch` repo, which provides a lot of different quantizer options. The bottleneck config is passed through to the `ResidualVQ` [constructor](https://github.com/lucidrains/vector-quantize-pytorch/blob/0c6cea24ce68510b607f2c9997e766d9d55c085b/vector_quantize_pytorch/residual_vq.py#L26).

**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training, as the random replacement is not synchronized across devices.**

### Example config
```json
"bottleneck": {
    "type": "rvq",
    "config": {
        "num_quantizers": 4,
        "codebook_size": 2048,
        "dim": 1024,
        "decay": 0.99
    }
}
```

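For a sense of what the underlying quantizer does, here is a small standalone sketch using `vector-quantize-pytorch` directly (the package is pinned in this commit's `setup.py`). This bypasses the library's bottleneck wrapper, and `ResidualVQ` expects channels-last `[batch, sequence, dim]` input:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# Hypothetical values mirroring the example config above.
rvq = ResidualVQ(dim=1024, num_quantizers=4, codebook_size=2048, decay=0.99)

latents = torch.randn(2, 256, 1024)              # [batch, sequence, dim]
quantized, indices, commit_loss = rvq(latents)

print(quantized.shape)   # same shape as the input latents
print(indices.shape)     # one codebook index per timestep, per quantizer
```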
## DAC RVQ
This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training.

The bottleneck config is passed directly into the `ResidualVectorQuantize` [constructor](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L97).

The `quantize_on_decode` property is also exposed, which moves the quantization process to the decoder. This should not be used during training, but is helpful when training latent diffusion models that use the quantization process as a way to remove error after the diffusion sampling process.

### Example config
```json
"bottleneck": {
    "type": "dac_rvq",
    "config": {
        "input_dim": 64,
        "n_codebooks": 9,
        "codebook_dim": 32,
        "codebook_size": 1024,
        "quantizer_dropout": 0.5
    }
}
```

### Extra info
The DAC RVQ bottleneck also adds the following properties to the `info` object:
- `pre_quantizer`
  - The pre-quantization latent series, useful in combination with `quantize_on_decode` for training latent diffusion models.
- `vq/commitment_loss`
  - Commitment loss for the quantizer
- `vq/codebook_loss`
  - Codebook loss for the quantizer
docs/conditioning.md ADDED
@@ -0,0 +1,151 @@
# Conditioning
Conditioning, in the context of `stable-audio-tools`, is the use of additional signals in a model that are used to add an additional level of control over the model's behavior. For example, we can condition the outputs of a diffusion model on a text prompt, creating a text-to-audio model.

# Conditioning types
There are a few different kinds of conditioning depending on the conditioning signal being used.

## Cross attention
Cross attention is a type of conditioning that allows us to find correlations between two sequences of potentially different lengths. For example, cross attention allows us to find correlations between a sequence of features from a text encoder and a sequence of high-level audio features.

Signals used for cross-attention conditioning should be of the shape `[batch, sequence, channels]`.

## Global conditioning
Global conditioning is the use of a single n-dimensional tensor to provide conditioning information that pertains to the whole sequence being conditioned. For example, this could be the single embedding output of a CLAP model, or a learned class embedding.

Signals used for global conditioning should be of the shape `[batch, channels]`.

## Input concatenation
Input concatenation applies a spatial conditioning signal to the model that correlates in the sequence dimension with the model's input, and is of the same length. The conditioning signal will be concatenated with the model's input data along the channel dimension. This can be used for things like inpainting information, melody conditioning, or for creating a diffusion autoencoder.

Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input.

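A quick sketch of the three tensor layouts described above (shapes only; the dimension sizes are made up for illustration):

```python
import torch

batch = 2

# Cross attention: a sequence of conditioning tokens, e.g. 77 T5 text features of width 768.
cross_attn_cond = torch.randn(batch, 77, 768)      # [batch, sequence, channels]

# Global conditioning: one vector per example, e.g. a single CLAP-style embedding.
global_cond = torch.randn(batch, 512)              # [batch, channels]

# Input concatenation: channels-first and the same length as the model input,
# e.g. a mono melody signal alongside a stereo 65536-sample training crop.
model_input = torch.randn(batch, 2, 65536)         # [batch, channels, sequence]
concat_cond = torch.randn(batch, 1, 65536)         # [batch, channels, sequence]
conditioned_input = torch.cat([model_input, concat_cond], dim=1)   # concatenated along channels
```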
# Conditioners and conditioning configs
`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input.

Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask.

The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. `{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`).

To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type.

The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals.

Each item in the `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner.

The `cond_dim` property is used to enforce the same dimension on all conditioning inputs; however, that can be overridden with an explicit `output_dim` property on any of the individual configs.

## Example config
```json
"conditioning": {
    "configs": [
        {
            "id": "prompt",
            "type": "t5",
            "config": {
                "t5_model_name": "t5-base",
                "max_length": 77,
                "project_out": true
            }
        }
    ],
    "cond_dim": 768
}
```

# Conditioners

## Text encoders

### `t5`
This uses a frozen [T5](https://huggingface.co/docs/transformers/model_doc/t5) text encoder from the `transformers` library to encode text prompts into a sequence of text features.

The `t5_model_name` property determines which T5 model is loaded from the `transformers` library.

The `max_length` property determines the maximum number of tokens that the text encoder will take in, as well as the sequence length of the output text features.

If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the T5 model.

T5 encodings are only compatible with cross attention conditioning.

#### Example config
```json
{
    "id": "prompt",
    "type": "t5",
    "config": {
        "t5_model_name": "t5-base",
        "max_length": 77,
        "project_out": true
    }
}
```

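For intuition about what a `t5` conditioner produces, here is a standalone sketch using `transformers` directly (it bypasses the library's conditioner classes; the resulting shapes match the cross-attention layout described earlier):

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

# "t5-base" and max_length=77 mirror the example config above.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
text_encoder = T5EncoderModel.from_pretrained("t5-base").eval()

prompts = ["birdsong at dawn near a river"]   # illustrative prompt
tokens = tokenizer(prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt")

with torch.no_grad():
    features = text_encoder(**tokens).last_hidden_state   # [batch, 77, 768] for t5-base

mask = tokens["attention_mask"].bool()   # which positions are real tokens vs. padding
print(features.shape, mask.shape)
```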
### `clap_text`
This loads the text encoder from a [CLAP](https://github.com/LAION-AI/CLAP) model, which can provide either a sequence of text features, or a single multimodal text/audio embedding.

The CLAP model must be provided with a local file path, set in the `clap_ckpt_path` property, along with the correct `audio_model_type` and `enable_fusion` properties for the provided model.

If the `use_text_features` property is set to `true`, the conditioner output will be a sequence of text features, instead of a single multimodal embedding. This allows for more fine-grained text information to be used by the model, at the cost of losing the ability to prompt with CLAP audio embeddings.

By default, if `use_text_features` is true, the last layer of the CLAP text encoder's features are returned. You can return the text features of earlier layers by specifying the index of the layer to return in the `feature_layer_ix` property. For example, you can return the text features of the next-to-last layer of the CLAP model by setting `feature_layer_ix` to `-2`.

If you set `enable_grad` to `true`, the CLAP model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the CLAP model.

CLAP text embeddings are compatible with global conditioning and cross attention conditioning. If `use_text_features` is set to `true`, the features are not compatible with global conditioning.

#### Example config
```json
{
    "id": "prompt",
    "type": "clap_text",
    "config": {
        "clap_ckpt_path": "/path/to/clap/model.ckpt",
        "audio_model_type": "HTSAT-base",
        "enable_fusion": true,
        "use_text_features": true,
        "feature_layer_ix": -2
    }
}
```

## Number encoders

### `int`
The IntConditioner takes in a list of integers in a given range, and returns a discrete learned embedding for each of those integers.

The `min_val` and `max_val` properties set the range of the embedding values. Input integers are clamped to this range.

This can be used for things like discrete timing embeddings, or learned class embeddings.

Int embeddings are compatible with global conditioning and cross attention conditioning.

#### Example config
```json
{
    "id": "seconds_start",
    "type": "int",
    "config": {
        "min_val": 0,
        "max_val": 512
    }
}
```

### `number`
The NumberConditioner takes in a list of floats in a given range, and returns a continuous Fourier embedding of the provided floats.

The `min_val` and `max_val` properties set the range of the float values. This is the range used to normalize the input float values.

Number embeddings are compatible with global conditioning and cross attention conditioning.

#### Example config
```json
{
    "id": "seconds_total",
    "type": "number",
    "config": {
        "min_val": 0,
        "max_val": 512
    }
}
```
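To illustrate the general idea behind a Fourier number embedding (a generic sketch of the technique, not the library's exact `NumberConditioner` implementation): the float is normalized using `min_val`/`max_val` and projected onto a bank of sinusoids.

```python
import torch

def fourier_number_embedding(values, min_val, max_val, num_freqs=8):
    """Generic Fourier-feature embedding of scalar conditioning values.

    values: tensor of shape [batch], e.g. seconds_total values.
    Returns a [batch, 2 * num_freqs] embedding.
    """
    x = (values - min_val) / (max_val - min_val)           # normalize to [0, 1]
    freqs = 2.0 ** torch.arange(num_freqs)                 # geometric frequency bank
    angles = 2 * torch.pi * x[:, None] * freqs[None, :]    # [batch, num_freqs]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = fourier_number_embedding(torch.tensor([22.0, 193.0]), min_val=0, max_val=512)
print(emb.shape)   # torch.Size([2, 16])
```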
docs/datasets.md ADDED
@@ -0,0 +1,75 @@
# Datasets
`stable-audio-tools` supports loading data from local file storage, as well as loading audio files and JSON files in the [WebDataset](https://github.com/webdataset/webdataset/tree/main/webdataset) format from Amazon S3 buckets.

# Dataset configs
To specify the dataset used for training, you must provide a dataset config JSON file to `train.py`.

The dataset config consists of a `dataset_type` property specifying the type of data loader to use, a `datasets` array to provide multiple data sources, and a `random_crop` property, which decides if the cropped audio from the training samples is taken from a random place in the audio file, or always from the beginning.

## Local audio files
To use a local directory of audio samples, set the `dataset_type` property in your dataset config to `"audio_dir"`, and provide a list of objects to the `datasets` property including the `path` property, which should be the path to your directory of audio samples.

This will load all of the compatible audio files from the provided directory and all subdirectories.

### Example config
```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "random_crop": true
}
```

## S3 WebDataset
To load audio files and related metadata from .tar files in the WebDataset format hosted in Amazon S3 buckets, you can set the `dataset_type` property to `s3`, and provide the `datasets` parameter with a list of objects containing the AWS S3 path to the shared S3 bucket prefix of the WebDataset .tar files. The S3 bucket will be searched recursively given the path, and any .tar files found are assumed to contain audio files and corresponding JSON files, where related files differ only in file extension (e.g. "000001.flac", "000001.json", "000002.flac", "000002.json", etc.)

### Example config
```json
{
    "dataset_type": "s3",
    "datasets": [
        {
            "id": "s3-test",
            "s3_path": "s3://my-bucket/datasets/webdataset/audio/"
        }
    ],
    "random_crop": true
}
```

# Custom metadata
To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info` and `audio`, and returns a dictionary.

For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For S3 WebDataset datasets, it will also contain the metadata from the related JSON files.

The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals.

The dictionary returned from the `get_custom_metadata` function will have its properties added to the `metadata` object used at training time. For more information on how conditioning works, please see the [Conditioning documentation](./conditioning.md)

## Example config and custom metadata module
```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "custom_metadata_module": "/path/to/custom_metadata.py",
    "random_crop": true
}
```

`custom_metadata.py`:
```py
def get_custom_metadata(info, audio):

    # Pass in the relative path of the audio file as the prompt
    return {"prompt": info["relpath"]}
```
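Building on the example above, a hypothetical `get_custom_metadata` that also derives a conditioning value from the audio itself (this assumes `audio` is a `[channels, samples]` torch tensor, which is worth verifying against your own dataset setup):

```python
import torch

def get_custom_metadata(info, audio):
    # Relative path as the text prompt, as in the example above.
    metadata = {"prompt": info["relpath"]}

    # Hypothetical extra signal derived from the audio: peak level of the crop.
    # This could feed a `number` conditioner configured with min_val=0, max_val=1.
    peak = float(audio.abs().max()) if torch.is_tensor(audio) else 0.0
    metadata["peak_level"] = peak

    return metadata
```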
docs/pretransforms.md ADDED
@@ -0,0 +1,43 @@
# Pretransforms
Many models require some fixed transform to be applied to the input audio before the audio is passed in to the trainable layers of the model, as well as a corresponding inverse transform to be applied to the outputs of the model. We refer to these as "pretransforms".

At the moment, `stable-audio-tools` supports two pretransforms: frozen autoencoders for latent diffusion models, and wavelet decompositions.

Pretransforms have a similar interface to autoencoders, with "encode" and "decode" functions defined for each pretransform.

## Autoencoder pretransform
To define a model with an autoencoder pretransform, you can define the "pretransform" property in the model config, with the `type` property set to `autoencoder`. The `config` property should be an autoencoder model definition.

Example:
```json
"pretransform": {
    "type": "autoencoder",
    "config": {
        "encoder": {
            ...
        },
        "decoder": {
            ...
        }
        ...normal autoencoder configuration
    }
}
```

### Latent rescaling
The original [Latent Diffusion paper](https://arxiv.org/abs/2112.10752) found that rescaling the latent series to unit variance before performing diffusion improved quality. To this end, we expose a `scale` property on autoencoder pretransforms that will take care of this rescaling. The scale should be set to the original standard deviation of the latents, which can be determined experimentally, or by looking at the `latent_std` value during training. The pretransform code will divide by this scale factor in the `encode` function and multiply by this scale in the `decode` function.

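In other words (a sketch of the rescaling convention only, not the library's actual wrapper class):

```python
import torch

class ScaledPretransformSketch:
    """Illustrative only: wraps a frozen autoencoder and applies the `scale` convention."""

    def __init__(self, autoencoder, scale: float):
        self.autoencoder = autoencoder
        self.scale = scale   # original std of the raw latents, e.g. observed as latent_std

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        latents = self.autoencoder.encode(audio)
        return latents / self.scale   # rescaled to roughly unit variance for diffusion

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        return self.autoencoder.decode(latents * self.scale)   # undo the rescaling before decoding
```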
## Wavelet pretransform
`stable-audio-tools` also exposes wavelet decomposition as a pretransform. Wavelet decomposition is a quick way to trade off sequence length for channels in autoencoders, while maintaining a multi-band implicit bias.

Wavelet pretransforms take the following properties:

- `channels`
  - The number of input and output audio channels for the wavelet transform
- `levels`
  - The number of successive wavelet decompositions to perform. Each level doubles the channel count and halves the sequence length
- `wavelet`
  - The specific wavelet from [PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html) to use, currently limited to `"bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"`

## Future work
We hope to add more filters and transforms to this list, including PQMF and STFT transforms.
model_config_float_conditioning_dit_all.json ADDED
@@ -0,0 +1,208 @@
{
    "model_type": "diffusion_cond",
    "sample_size": 1048576,
    "sample_rate": 44100,
    "audio_channels": 1,
    "model": {
        "pretransform": {
            "type": "autoencoder",
            "iterate_batch": true,
            "config": {
                "encoder": {
                    "type": "dac",
                    "config": {
                        "in_channels": 1,
                        "latent_dim": 128,
                        "d_model": 128,
                        "strides": [4, 4, 8, 8]
                    }
                },
                "decoder": {
                    "type": "dac",
                    "config": {
                        "out_channels": 1,
                        "latent_dim": 64,
                        "channels": 1536,
                        "rates": [8, 8, 4, 4]
                    }
                },
                "bottleneck": {
                    "type": "vae"
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,
                "io_channels": 1
            }
        },
        "conditioning": {
            "configs": [
                {
                    "id": "latitude",
                    "type": "number",
                    "config": {
                        "min_val": -54.617412,
                        "max_val": -10.13994
                    }
                },
                {
                    "id": "longitude",
                    "type": "number",
                    "config": {
                        "min_val": 96.8233,
                        "max_val": 167.9619
                    }
                },
                {
                    "id": "temperature",
                    "type": "number",
                    "config": {
                        "min_val": -10.0,
                        "max_val": 55.0
                    }
                },
                {
                    "id": "humidity",
                    "type": "number",
                    "config": {
                        "min_val": 1,
                        "max_val": 100.0
                    }
                },
                {
                    "id": "wind_speed",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 50.0
                    }
                },
                {
                    "id": "pressure",
                    "type": "number",
                    "config": {
                        "min_val": 800.0,
                        "max_val": 1200.0
                    }
                },
                {
                    "id": "minutes_of_day",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 1439
                    }
                },
                {
                    "id": "day_of_year",
                    "type": "number",
                    "config": {
                        "min_val": 1,
                        "max_val": 366
                    }
                },
                {
                    "id": "seconds_start",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                },
                {
                    "id": "seconds_total",
                    "type": "number",
                    "config": {
                        "min_val": 0,
                        "max_val": 512
                    }
                }
            ],
            "cond_dim": 768
        },
        "diffusion": {
            "cross_attention_cond_ids": ["latitude", "longitude", "temperature", "humidity", "wind_speed", "pressure", "minutes_of_day", "day_of_year", "seconds_start", "seconds_total"],
            "global_cond_ids": ["seconds_start", "seconds_total"],
            "type": "dit",
            "config": {
                "io_channels": 64,
                "embed_dim": 768,
                "depth": 24,
                "num_heads": 24,
                "cond_token_dim": 768,
                "global_cond_dim": 1536,
                "project_cond_tokens": false,
                "transformer_type": "continuous_transformer"
            }
        },
        "io_channels": 64
    },

    "training": {
        "use_ema": true,
        "log_loss_info": false,
        "optimizer_configs": {
            "diffusion": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "lr": 5e-5,
                        "betas": [0.9, 0.999],
                        "weight_decay": 1e-3
                    }
                },
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 1000000,
                        "power": 0.5,
                        "warmup": 0.99
                    }
                }
            }
        },
        "demo": {
            "demo_every": 2500,
            "demo_steps": 100,
            "num_demos": 3,
            "demo_cfg_scales": [3, 5, 7],
            "demo_cond": [
                {
                    "latitude": -24.005512,
                    "longitude": 133.368348,
                    "temperature": 25.5,
                    "humidity": 60,
                    "wind_speed": 8,
                    "pressure": 1000,
                    "minutes_of_day": 400,
                    "day_of_year": 110,
                    "seconds_start": 0,
                    "seconds_total": 22
                },
                {
                    "latitude": -26.987815,
                    "longitude": 153.129068,
                    "temperature": 31.5,
                    "humidity": 70,
                    "wind_speed": 12,
                    "pressure": 1010,
                    "minutes_of_day": 600,
                    "day_of_year": 57,
                    "seconds_start": 0,
                    "seconds_total": 22
                },
                {
                    "latitude": -12.546364,
                    "longitude": 130.919605,
                    "temperature": 28.5,
                    "humidity": 60,
                    "wind_speed": 18,
                    "pressure": 1015,
                    "minutes_of_day": 1140,
                    "day_of_year": 280,
                    "seconds_start": 0,
                    "seconds_total": 22
                }
            ]
        }
    }
}
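A sketch of how a config like this can be turned into a model object, using the factory function exported by the package (see `stable_audio_tools/__init__.py` below); checkpoint loading is omitted here and depends on how training was run:

```python
import json
from stable_audio_tools import create_model_from_config

# Load the float-conditioning DiT configuration committed above.
with open("model_config_float_conditioning_dit_all.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)

# Basic sanity checks against the config values.
print(model_config["sample_rate"])                  # 44100
print(sum(p.numel() for p in model.parameters()))   # total parameter count
```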
pyproject.toml ADDED
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
run_gradio.py ADDED
@@ -0,0 +1,31 @@
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.interface.gradio import create_ui
import json

import torch

def main(args):
    torch.manual_seed(42)

    interface = create_ui(
        model_config_path = args.model_config,
        ckpt_path=args.ckpt_path,
        pretrained_name=args.pretrained_name,
        pretransform_ckpt_path=args.pretransform_ckpt_path,
        model_half=args.model_half
    )
    interface.queue()
    interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Run gradio interface')
    parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
    parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
    parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
    parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional path to model pretransform checkpoint', required=False)
    parser.add_argument('--username', type=str, help='Gradio username', required=False)
    parser.add_argument('--password', type=str, help='Gradio password', required=False)
    parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
    args = parser.parse_args()
    main(args)
run_tests.py ADDED
@@ -0,0 +1,44 @@
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.interface.testing import runTests
print(runTests)  # Check if it prints a function reference

import torch

def main(args):
    torch.manual_seed(42)
    runTests(
        model_config_path = args.model_config,
        ckpt_path=args.ckpt_path,
        pretrained_name=args.pretrained_name,
        pretransform_ckpt_path=args.pretransform_ckpt_path,
        model_half=args.model_half,
        output_dir=args.output_dir,
        json_dir=args.json_dir
    )

if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(description='Run generation tests')
    parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
    parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
    parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
    parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional path to model pretransform checkpoint', required=False)
    parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
    parser.add_argument('--output-dir', type=str, help='Path to output directory', required=True)
    parser.add_argument('--json-dir', type=str, help='Path to directory containing JSON files', required=True)
    print("Running tests")

    print("Arguments provided:", sys.argv[1:])

    args = parser.parse_args()
    print("Parsed arguments:", args)
    main(args)
scripts/ds_zero_to_pl_ckpt.py ADDED
@@ -0,0 +1,14 @@
import argparse
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint")
    parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt")
    args = parser.parse_args()

    # lightning deepspeed has saved a directory instead of a file
    save_path = args.save_path
    output_path = args.output_path
    convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)
setup.py ADDED
@@ -0,0 +1,46 @@
from setuptools import setup, find_packages

setup(
    name='stable-audio-tools',
    version='0.0.12',
    url='https://github.com/Stability-AI/stable-audio-tools.git',
    author='Stability AI',
    description='Training and inference tools for generative audio models from Stability AI',
    packages=find_packages(),
    install_requires=[
        'audiocraft==1.0.0',
        'aeiou==0.0.20',
        'alias-free-torch==0.0.6',
        'auraloss==0.4.0',
        'descript-audio-codec==1.0.0',
        'einops==0.7.0',
        'einops-exts==0.0.4',
        'ema-pytorch==0.2.3',
        'encodec==0.1.1',
        'flash-attn>=2.5.0',
        'gradio>=3.42.0',
        'huggingface_hub',
        'importlib-resources==5.12.0',
        'k-diffusion==0.1.1',
        'laion-clap==1.1.4',
        'local-attention==1.8.6',
        'pandas==2.0.2',
        'pedalboard==0.7.4',
        'prefigure==0.0.9',
        'pytorch_lightning==2.1.0',
        'PyWavelets==1.4.1',
        'safetensors',
        'sentencepiece==0.1.99',
        's3fs',
        'torch>=2.0.1',
        'torchaudio>=2.0.2',
        'torchmetrics==0.11.4',
        'tqdm',
        'transformers==4.33.3',
        'v-diffusion-pytorch==0.0.2',
        'vector-quantize-pytorch==1.9.14',
        'wandb==0.15.4',
        'webdataset==0.2.48',
        'x-transformers<1.27.0'
    ],
)
stable_audio_tools/__init__.py ADDED
@@ -0,0 +1,2 @@
from .models.factory import create_model_from_config, create_model_from_config_path
from .models.pretrained import get_pretrained_model
stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py ADDED
@@ -0,0 +1,4 @@
def get_custom_metadata(info, audio):

    # Use relative path as the prompt
    return {"prompt": info["relpath"]}
stable_audio_tools/configs/dataset_configs/local_training_example.json ADDED
@@ -0,0 +1,11 @@
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "custom_metadata_module": "/path/to/custom_metadata/custom_md_example.py",
    "random_crop": true
}
stable_audio_tools/configs/dataset_configs/s3_wds_example.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "dataset_type": "s3",
3
+ "datasets": [
4
+ {
5
+ "id": "s3-test",
6
+ "s3_path": "s3://my-bucket/datasets/webdataset/audio/"
7
+ }
8
+ ],
9
+ "random_crop": true
10
+ }
stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "dac",
9
+ "config": {
10
+ "latent_dim": 64,
11
+ "d_model": 128,
12
+ "strides": [4, 8, 8, 8]
13
+ }
14
+ },
15
+ "decoder": {
16
+ "type": "dac",
17
+ "config": {
18
+ "latent_dim": 32,
19
+ "channels": 1536,
20
+ "rates": [8, 8, 8, 4]
21
+ }
22
+ },
23
+ "bottleneck": {
24
+ "type": "vae"
25
+ },
26
+ "latent_dim": 32,
27
+ "downsampling_ratio": 2048,
28
+ "io_channels": 1
29
+ },
30
+ "training": {
31
+ "learning_rate": 1e-4,
32
+ "warmup_steps": 0,
33
+ "use_ema": false,
34
+ "loss_configs": {
35
+ "discriminator": {
36
+ "type": "encodec",
37
+ "config": {
38
+ "filters": 32,
39
+ "n_ffts": [2048, 1024, 512, 256, 128, 64, 32],
40
+ "hop_lengths": [512, 256, 128, 64, 32, 16, 8],
41
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32]
42
+ },
43
+ "weights": {
44
+ "adversarial": 0.1,
45
+ "feature_matching": 5.0
46
+ }
47
+ },
48
+ "spectral": {
49
+ "type": "mrstft",
50
+ "config": {
51
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
52
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
53
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
54
+ "perceptual_weighting": true
55
+ },
56
+ "weights": {
57
+ "mrstft": 1.0
58
+ }
59
+ },
60
+ "time": {
61
+ "type": "l1",
62
+ "weights": {
63
+ "l1": 0.0
64
+ }
65
+ }
66
+ },
67
+ "demo": {
68
+ "demo_every": 2000
69
+ }
70
+ }
71
+ }
stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json ADDED
@@ -0,0 +1,88 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 32000,
4
+ "sample_rate": 32000,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "seanet",
9
+ "config": {
10
+ "channels": 1,
11
+ "dimension": 128,
12
+ "n_filters": 64,
13
+ "ratios": [4, 4, 5, 8],
14
+ "n_residual_layers": 1,
15
+ "dilation_base": 2,
16
+ "lstm": 2,
17
+ "norm": "weight_norm"
18
+ }
19
+ },
20
+ "decoder": {
21
+ "type": "seanet",
22
+ "config": {
23
+ "channels": 1,
24
+ "dimension": 128,
25
+ "n_filters": 64,
26
+ "ratios": [4, 4, 5, 8],
27
+ "n_residual_layers": 1,
28
+ "dilation_base": 2,
29
+ "lstm": 2,
30
+ "norm": "weight_norm"
31
+ }
32
+ },
33
+ "bottleneck": {
34
+ "type": "rvq",
35
+ "config": {
36
+ "num_quantizers": 4,
37
+ "codebook_size": 2048,
38
+ "dim": 128,
39
+ "decay": 0.99,
40
+ "threshold_ema_dead_code": 2
41
+ }
42
+ },
43
+ "latent_dim": 128,
44
+ "downsampling_ratio": 640,
45
+ "io_channels": 1
46
+ },
47
+ "training": {
48
+ "learning_rate": 1e-4,
49
+ "warmup_steps": 0,
50
+ "use_ema": true,
51
+ "loss_configs": {
52
+ "discriminator": {
53
+ "type": "encodec",
54
+ "config": {
55
+ "filters": 32,
56
+ "n_ffts": [2048, 1024, 512, 256, 128],
57
+ "hop_lengths": [512, 256, 128, 64, 32],
58
+ "win_lengths": [2048, 1024, 512, 256, 128]
59
+ },
60
+ "weights": {
61
+ "adversarial": 0.1,
62
+ "feature_matching": 5.0
63
+ }
64
+ },
65
+ "spectral": {
66
+ "type": "mrstft",
67
+ "config": {
68
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
69
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
70
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
71
+ "perceptual_weighting": true
72
+ },
73
+ "weights": {
74
+ "mrstft": 1.0
75
+ }
76
+ },
77
+ "time": {
78
+ "type": "l1",
79
+ "weights": {
80
+ "l1": 0.0
81
+ }
82
+ }
83
+ },
84
+ "demo": {
85
+ "demo_every": 2000
86
+ }
87
+ }
88
+ }
stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json ADDED
@@ -0,0 +1,111 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "dac",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "latent_dim": 128,
12
+ "d_model": 128,
13
+ "strides": [4, 4, 8, 8]
14
+ }
15
+ },
16
+ "decoder": {
17
+ "type": "dac",
18
+ "config": {
19
+ "out_channels": 2,
20
+ "latent_dim": 64,
21
+ "channels": 1536,
22
+ "rates": [8, 8, 4, 4]
23
+ }
24
+ },
25
+ "bottleneck": {
26
+ "type": "vae"
27
+ },
28
+ "latent_dim": 64,
29
+ "downsampling_ratio": 1024,
30
+ "io_channels": 2
31
+ },
32
+ "training": {
33
+ "learning_rate": 1e-4,
34
+ "warmup_steps": 0,
35
+ "use_ema": true,
36
+ "optimizer_configs": {
37
+ "autoencoder": {
38
+ "optimizer": {
39
+ "type": "AdamW",
40
+ "config": {
41
+ "betas": [0.8, 0.99],
42
+ "lr": 1e-4
43
+ }
44
+ },
45
+ "scheduler": {
46
+ "type": "ExponentialLR",
47
+ "config": {
48
+ "gamma": 0.999996
49
+ }
50
+ }
51
+ },
52
+ "discriminator": {
53
+ "optimizer": {
54
+ "type": "AdamW",
55
+ "config": {
56
+ "betas": [0.8, 0.99],
57
+ "lr": 1e-4
58
+ }
59
+ },
60
+ "scheduler": {
61
+ "type": "ExponentialLR",
62
+ "config": {
63
+ "gamma": 0.999996
64
+ }
65
+ }
66
+ }
67
+ },
68
+ "loss_configs": {
69
+ "discriminator": {
70
+ "type": "encodec",
71
+ "config": {
72
+ "filters": 32,
73
+ "n_ffts": [2048, 1024, 512, 256, 128],
74
+ "hop_lengths": [512, 256, 128, 64, 32],
75
+ "win_lengths": [2048, 1024, 512, 256, 128]
76
+ },
77
+ "weights": {
78
+ "adversarial": 0.1,
79
+ "feature_matching": 5.0
80
+ }
81
+ },
82
+ "spectral": {
83
+ "type": "mrstft",
84
+ "config": {
85
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
86
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
87
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
88
+ "perceptual_weighting": true
89
+ },
90
+ "weights": {
91
+ "mrstft": 1.0
92
+ }
93
+ },
94
+ "time": {
95
+ "type": "l1",
96
+ "weights": {
97
+ "l1": 0.0
98
+ }
99
+ },
100
+ "bottleneck": {
101
+ "type": "kl",
102
+ "weights": {
103
+ "kl": 1e-6
104
+ }
105
+ }
106
+ },
107
+ "demo": {
108
+ "demo_every": 2000
109
+ }
110
+ }
111
+ }
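One relationship worth noting in the VAE config above: the declared downsampling_ratio appears to equal the product of the encoder strides (4 * 4 * 8 * 8 = 1024 here), and the same holds for the other stride-based autoencoder configs in this commit (for example 2 * 4 * 4 * 8 * 8 = 2048 in the stable_audio_2_0 VAE). A small sanity-check sketch, assuming the config has been loaded as a dict and that the encoder config exposes a "strides" key as the DAC and Oobleck encoders above do:

import json
import math

def check_downsampling_ratio(config_path):
    # Verify that the encoder strides multiply out to the declared downsampling ratio.
    with open(config_path) as f:
        model_config = json.load(f)
    strides = model_config["model"]["encoder"]["config"]["strides"]
    declared = model_config["model"]["downsampling_ratio"]
    assert math.prod(strides) == declared, (math.prod(strides), declared)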
stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json ADDED
@@ -0,0 +1,122 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8, 16],
13
+ "strides": [2, 4, 4, 8, 8],
14
+ "latent_dim": 128,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 8, 8],
25
+ "latent_dim": 64,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 64,
34
+ "downsampling_ratio": 2048,
35
+ "io_channels": 2
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": true,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [2048, 1024, 512, 256, 128],
85
+ "hop_lengths": [512, 256, 128, 64, 32],
86
+ "win_lengths": [2048, 1024, 512, 256, 128]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
97
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
98
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 2000
120
+ }
121
+ }
122
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 48000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 16000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 4e-5,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "model_type": "diffusion_uncond",
3
+ "sample_size": 131072,
4
+ "sample_rate": 48000,
5
+ "model": {
6
+ "type": "DAU1d",
7
+ "config": {
8
+ "n_attn_layers": 5
9
+ }
10
+ },
11
+ "training": {
12
+ "learning_rate": 1e-4,
13
+ "demo": {
14
+ "demo_every": 2000,
15
+ "demo_steps": 250
16
+ }
17
+ }
18
+ }
stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "model_type": "musicgen",
3
+ "sample_size": 320000,
4
+ "sample_rate": 32000,
5
+ "audio_channels": 1,
6
+ "model": {
7
+ "pretrained": "small"
8
+ },
9
+ "training": {
10
+ "learning_rate": 1e-4,
11
+ "demo": {
12
+ "demo_every": 2000,
13
+ "demo_cond": [
14
+ {"prompt": "Keywords: Atmospheres, Orchestral Drone, Bass, Sci-Fi Ambient Soundscape, Synthesiser, Middle Eastern Vocal, dramatic piano"},
15
+ {"prompt": "Genre: Corporate|Instruments: Ukulele, Drums, Clapping, Glockenspiel"},
16
+ {"prompt": "Description: 116 BPM rock drums, drum track for a rock song"},
17
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle."}
18
+ ],
19
+ "demo_cfg_scales": [3, 6, 9]
20
+ }
21
+ }
22
+ }
stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "model_type": "diffusion_cond",
3
+ "sample_size": 4194304,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "dac",
13
+ "config": {
14
+ "in_channels": 2,
15
+ "latent_dim": 128,
16
+ "d_model": 128,
17
+ "strides": [4, 4, 8, 8]
18
+ }
19
+ },
20
+ "decoder": {
21
+ "type": "dac",
22
+ "config": {
23
+ "out_channels": 2,
24
+ "latent_dim": 64,
25
+ "channels": 1536,
26
+ "rates": [8, 8, 4, 4]
27
+ }
28
+ },
29
+ "bottleneck": {
30
+ "type": "vae"
31
+ },
32
+ "latent_dim": 64,
33
+ "downsampling_ratio": 1024,
34
+ "io_channels": 2
35
+ }
36
+ },
37
+ "conditioning": {
38
+ "configs": [
39
+ {
40
+ "id": "prompt",
41
+ "type": "clap_text",
42
+ "config": {
43
+ "audio_model_type": "HTSAT-base",
44
+ "enable_fusion": true,
45
+ "clap_ckpt_path": "/path/to/clap.ckpt",
46
+ "use_text_features": true,
47
+ "feature_layer_ix": -2
48
+ }
49
+ },
50
+ {
51
+ "id": "seconds_start",
52
+ "type": "int",
53
+ "config": {
54
+ "min_val": 0,
55
+ "max_val": 512
56
+ }
57
+ },
58
+ {
59
+ "id": "seconds_total",
60
+ "type": "int",
61
+ "config": {
62
+ "min_val": 0,
63
+ "max_val": 512
64
+ }
65
+ }
66
+ ],
67
+ "cond_dim": 768
68
+ },
69
+ "diffusion": {
70
+ "type": "adp_cfg_1d",
71
+ "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
72
+ "config": {
73
+ "in_channels": 64,
74
+ "context_embedding_features": 768,
75
+ "context_embedding_max_length": 79,
76
+ "channels": 256,
77
+ "resnet_groups": 16,
78
+ "kernel_multiplier_downsample": 2,
79
+ "multipliers": [4, 4, 4, 5, 5],
80
+ "factors": [1, 2, 2, 4],
81
+ "num_blocks": [2, 2, 2, 2],
82
+ "attentions": [1, 3, 3, 3, 3],
83
+ "attention_heads": 16,
84
+ "attention_multiplier": 4,
85
+ "use_nearest_upsample": false,
86
+ "use_skip_scale": true,
87
+ "use_context_time": true
88
+ }
89
+ },
90
+ "io_channels": 64
91
+ },
92
+ "training": {
93
+ "learning_rate": 4e-5,
94
+ "demo": {
95
+ "demo_every": 2000,
96
+ "demo_steps": 250,
97
+ "num_demos": 4,
98
+ "demo_cond": [
99
+ {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95},
100
+ {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90},
101
+ {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180},
102
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60}
103
+ ],
104
+ "demo_cfg_scales": [3, 6, 9]
105
+ }
106
+ }
107
+ }
stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json ADDED
@@ -0,0 +1,127 @@
1
+ {
2
+ "model_type": "diffusion_cond",
3
+ "sample_size": 12582912,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "oobleck",
13
+ "config": {
14
+ "in_channels": 2,
15
+ "channels": 128,
16
+ "c_mults": [1, 2, 4, 8, 16],
17
+ "strides": [2, 4, 4, 8, 8],
18
+ "latent_dim": 128,
19
+ "use_snake": true
20
+ }
21
+ },
22
+ "decoder": {
23
+ "type": "oobleck",
24
+ "config": {
25
+ "out_channels": 2,
26
+ "channels": 128,
27
+ "c_mults": [1, 2, 4, 8, 16],
28
+ "strides": [2, 4, 4, 8, 8],
29
+ "latent_dim": 64,
30
+ "use_snake": true,
31
+ "final_tanh": false
32
+ }
33
+ },
34
+ "bottleneck": {
35
+ "type": "vae"
36
+ },
37
+ "latent_dim": 64,
38
+ "downsampling_ratio": 2048,
39
+ "io_channels": 2
40
+ }
41
+ },
42
+ "conditioning": {
43
+ "configs": [
44
+ {
45
+ "id": "prompt",
46
+ "type": "clap_text",
47
+ "config": {
48
+ "audio_model_type": "HTSAT-base",
49
+ "enable_fusion": true,
50
+ "clap_ckpt_path": "/path/to/clap.ckpt",
51
+ "use_text_features": true,
52
+ "feature_layer_ix": -2
53
+ }
54
+ },
55
+ {
56
+ "id": "seconds_start",
57
+ "type": "number",
58
+ "config": {
59
+ "min_val": 0,
60
+ "max_val": 512
61
+ }
62
+ },
63
+ {
64
+ "id": "seconds_total",
65
+ "type": "number",
66
+ "config": {
67
+ "min_val": 0,
68
+ "max_val": 512
69
+ }
70
+ }
71
+ ],
72
+ "cond_dim": 768
73
+ },
74
+ "diffusion": {
75
+ "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
76
+ "global_cond_ids": ["seconds_start", "seconds_total"],
77
+ "type": "dit",
78
+ "config": {
79
+ "io_channels": 64,
80
+ "embed_dim": 1536,
81
+ "depth": 24,
82
+ "num_heads": 24,
83
+ "cond_token_dim": 768,
84
+ "global_cond_dim": 1536,
85
+ "project_cond_tokens": false,
86
+ "transformer_type": "continuous_transformer"
87
+ }
88
+ },
89
+ "io_channels": 64
90
+ },
91
+ "training": {
92
+ "use_ema": true,
93
+ "log_loss_info": false,
94
+ "optimizer_configs": {
95
+ "diffusion": {
96
+ "optimizer": {
97
+ "type": "AdamW",
98
+ "config": {
99
+ "lr": 5e-5,
100
+ "betas": [0.9, 0.999],
101
+ "weight_decay": 1e-3
102
+ }
103
+ },
104
+ "scheduler": {
105
+ "type": "InverseLR",
106
+ "config": {
107
+ "inv_gamma": 1000000,
108
+ "power": 0.5,
109
+ "warmup": 0.99
110
+ }
111
+ }
112
+ }
113
+ },
114
+ "demo": {
115
+ "demo_every": 2000,
116
+ "demo_steps": 250,
117
+ "num_demos": 4,
118
+ "demo_cond": [
119
+ {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80},
120
+ {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250},
121
+ {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180},
122
+ {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 190}
123
+ ],
124
+ "demo_cfg_scales": [3, 6, 9]
125
+ }
126
+ }
127
+ }
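For orientation, the sample_size of 12582912 in the config above works out to 12582912 / 44100 ≈ 285 seconds of stereo audio per training window, and with the VAE pretransform's downsampling ratio of 2048 it corresponds to 12582912 / 2048 = 6144 latent frames seen by the diffusion transformer (the inference code later in this commit divides sample_size by the pretransform's downsampling_ratio in the same way). A purely illustrative check:

sample_size = 12582912
sample_rate = 44100
downsampling_ratio = 2048

print(sample_size / sample_rate)          # ~285.3 seconds of audio per window
print(sample_size // downsampling_ratio)  # 6144 latent frames fed to the DiT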
stable_audio_tools/data/__init__.py ADDED
File without changes
stable_audio_tools/data/dataset.py ADDED
@@ -0,0 +1,597 @@
1
+ import importlib
2
+ import numpy as np
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import random
7
+ import re
8
+ import subprocess
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+ import webdataset as wds
13
+
14
+ from aeiou.core import is_silence
15
+ from os import path
16
+ from pedalboard.io import AudioFile
17
+ from torchaudio import transforms as T
18
+ from typing import Optional, Callable, List
19
+
20
+ from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
21
+
22
+ AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
23
+
24
+ # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
25
+
26
+ def fast_scandir(
27
+ dir:str, # top-level directory at which to begin scanning
28
+ ext:list, # list of allowed file extensions,
29
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
30
+ ):
31
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
32
+ subfolders, files = [], []
33
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
34
+ try: # hope to avoid 'permission denied' by this try
35
+ for f in os.scandir(dir):
36
+ try: # 'hope to avoid too many levels of symbolic links' error
37
+ if f.is_dir():
38
+ subfolders.append(f.path)
39
+ elif f.is_file():
40
+ file_ext = os.path.splitext(f.name)[1].lower()
41
+ is_hidden = os.path.basename(f.path).startswith(".")
42
+
43
+ if file_ext in ext and not is_hidden:
44
+ files.append(f.path)
45
+ except:
46
+ pass
47
+ except:
48
+ pass
49
+
50
+ for dir in list(subfolders):
51
+ sf, f = fast_scandir(dir, ext)
52
+ subfolders.extend(sf)
53
+ files.extend(f)
54
+ return subfolders, files
55
+
56
+ def keyword_scandir(
57
+ dir: str, # top-level directory at which to begin scanning
58
+ ext: list, # list of allowed file extensions
59
+ keywords: list, # list of keywords to search for in the file name
60
+ ):
61
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
62
+ subfolders, files = [], []
63
+ # make keywords case insensitive
64
+ keywords = [keyword.lower() for keyword in keywords]
65
+ # add starting period to extensions if needed
66
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
67
+ banned_words = ["paxheader", "__macosx"]
68
+ try: # hope to avoid 'permission denied' by this try
69
+ for f in os.scandir(dir):
70
+ try: # 'hope to avoid too many levels of symbolic links' error
71
+ if f.is_dir():
72
+ subfolders.append(f.path)
73
+ elif f.is_file():
74
+ is_hidden = f.name.split("/")[-1][0] == '.'
75
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
76
+ name_lower = f.name.lower()
77
+ has_keyword = any(
78
+ [keyword in name_lower for keyword in keywords])
79
+ has_banned = any(
80
+ [banned_word in name_lower for banned_word in banned_words])
81
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
82
+ files.append(f.path)
83
+ except:
84
+ pass
85
+ except:
86
+ pass
87
+
88
+ for dir in list(subfolders):
89
+ sf, f = keyword_scandir(dir, ext, keywords)
90
+ subfolders.extend(sf)
91
+ files.extend(f)
92
+ return subfolders, files
93
+
94
+ def get_audio_filenames(
95
+ paths: list, # directories in which to search
96
+ keywords=None,
97
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
98
+ ):
99
+ "recursively get a list of audio filenames"
100
+ filenames = []
101
+ if type(paths) is str:
102
+ paths = [paths]
103
+ for path in paths: # get a list of relevant filenames
104
+ if keywords is not None:
105
+ subfolders, files = keyword_scandir(path, exts, keywords)
106
+ else:
107
+ subfolders, files = fast_scandir(path, exts)
108
+ filenames.extend(files)
109
+ return filenames
110
+
111
+ class SampleDataset(torch.utils.data.Dataset):
112
+ def __init__(
113
+ self,
114
+ paths,
115
+ sample_size=65536,
116
+ sample_rate=48000,
117
+ keywords=None,
118
+ relpath=None,
119
+ random_crop=True,
120
+ force_channels="stereo",
121
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
122
+ ):
123
+ super().__init__()
124
+ self.filenames = []
125
+ self.relpath = relpath
126
+
127
+ self.augs = torch.nn.Sequential(
128
+ PhaseFlipper(),
129
+ )
130
+
131
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
132
+
133
+ self.force_channels = force_channels
134
+
135
+ self.encoding = torch.nn.Sequential(
136
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
137
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
138
+ )
139
+
140
+ self.filenames = get_audio_filenames(paths, keywords)
141
+
142
+ print(f'Found {len(self.filenames)} files')
143
+
144
+ self.sr = sample_rate
145
+
146
+ self.custom_metadata_fn = custom_metadata_fn
147
+
148
+ def load_file(self, filename):
149
+ ext = filename.split(".")[-1]
150
+
151
+ if ext == "mp3":
152
+ with AudioFile(filename) as f:
153
+ audio = f.read(f.frames)
154
+ audio = torch.from_numpy(audio)
155
+ in_sr = f.samplerate
156
+ else:
157
+ audio, in_sr = torchaudio.load(filename, format=ext)
158
+
159
+ if in_sr != self.sr:
160
+ resample_tf = T.Resample(in_sr, self.sr)
161
+ audio = resample_tf(audio)
162
+
163
+ return audio
164
+
165
+ def __len__(self):
166
+ return len(self.filenames)
167
+
168
+ def __getitem__(self, idx):
169
+ audio_filename = self.filenames[idx]
170
+ try:
171
+ start_time = time.time()
172
+ audio = self.load_file(audio_filename)
173
+
174
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
175
+
176
+ # Run augmentations on this sample (including random crop)
177
+ if self.augs is not None:
178
+ audio = self.augs(audio)
179
+
180
+ audio = audio.clamp(-1, 1)
181
+
182
+ # Encode the file to assist in prediction
183
+ if self.encoding is not None:
184
+ audio = self.encoding(audio)
185
+
186
+ info = {}
187
+
188
+ info["path"] = audio_filename
189
+
190
+ if self.relpath is not None:
191
+ info["relpath"] = path.relpath(audio_filename, self.relpath)
192
+
193
+ info["timestamps"] = (t_start, t_end)
194
+ info["seconds_start"] = seconds_start
195
+ info["seconds_total"] = seconds_total
196
+ info["padding_mask"] = padding_mask
197
+
198
+ end_time = time.time()
199
+
200
+ info["load_time"] = end_time - start_time
201
+
202
+ if self.custom_metadata_fn is not None:
203
+ custom_metadata = self.custom_metadata_fn(info, audio)
204
+ info.update(custom_metadata)
205
+
206
+ if "__reject__" in info and info["__reject__"]:
207
+ return self[random.randrange(len(self))]
208
+
209
+ return (audio, info)
210
+ except Exception as e:
211
+ print(f'Couldn\'t load file {audio_filename}: {e}')
212
+ return self[random.randrange(len(self))]
213
+
214
+ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
215
+ """Return function over iterator that groups key, value pairs into samples.
216
+ :param keys: function that splits the key into key and extension (base_plus_ext)
217
+ :param lcase: convert suffixes to lower case (Default value = True)
218
+ """
219
+ current_sample = None
220
+ for filesample in data:
221
+ assert isinstance(filesample, dict)
222
+ fname, value = filesample["fname"], filesample["data"]
223
+ prefix, suffix = keys(fname)
224
+ if wds.tariterators.trace:
225
+ print(
226
+ prefix,
227
+ suffix,
228
+ current_sample.keys() if isinstance(current_sample, dict) else None,
229
+ )
230
+ if prefix is None:
231
+ continue
232
+ if lcase:
233
+ suffix = suffix.lower()
234
+ if current_sample is None or prefix != current_sample["__key__"]:
235
+ if wds.tariterators.valid_sample(current_sample):
236
+ yield current_sample
237
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
238
+ if suffix in current_sample:
239
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
240
+ if suffixes is None or suffix in suffixes:
241
+ current_sample[suffix] = value
242
+ if wds.tariterators.valid_sample(current_sample):
243
+ yield current_sample
244
+
245
+ wds.tariterators.group_by_keys = group_by_keys
246
+
247
+ # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
248
+
249
+ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
250
+ """
251
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
252
+ """
253
+ # Ensure dataset_path ends with a trailing slash
254
+ if dataset_path != '' and not dataset_path.endswith('/'):
255
+ dataset_path += '/'
256
+ # Use posixpath to construct the S3 URL path
257
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
258
+ # Construct the `aws s3 ls` command
259
+ cmd = ['aws', 's3', 'ls', bucket_path]
260
+
261
+ if profile is not None:
262
+ cmd.extend(['--profile', profile])
263
+
264
+ if recursive:
265
+ # Add the --recursive flag if requested
266
+ cmd.append('--recursive')
267
+
268
+ # Run the `aws s3 ls` command and capture the output
269
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
270
+ # Split the output into lines and strip whitespace from each line
271
+ contents = run_ls.stdout.decode('utf-8').split('\n')
272
+ contents = [x.strip() for x in contents if x]
273
+ # Remove the timestamp from lines that begin with a timestamp
274
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
275
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
276
+ # Construct a full S3 path for each file in the contents list
277
+ contents = [posixpath.join(s3_url_prefix or '', x)
278
+ for x in contents if not x.endswith('/')]
279
+ # Apply the filter, if specified
280
+ if filter:
281
+ contents = [x for x in contents if filter in x]
282
+ # Remove redundant directory names in the S3 URL
283
+ if recursive:
284
+ # Get the main directory name from the S3 URL
285
+ main_dir = "/".join(bucket_path.split('/')[3:])
286
+ # Remove the redundant directory names from each file path
287
+ contents = [x.replace(f'{main_dir}', '').replace(
288
+ '//', '/') for x in contents]
289
+ # Print debugging information, if requested
290
+ if debug:
291
+ print("contents = \n", contents)
292
+ # Return the list of S3 paths to files
293
+ return contents
294
+
295
+
296
+ def get_all_s3_urls(
297
+ names=[], # list of all valid [LAION AudioDataset] dataset names
298
+ # list of subsets you want from those datasets, e.g. ['train','valid']
299
+ subsets=[''],
300
+ s3_url_prefix=None, # prefix for those dataset names
301
+ recursive=True, # recursively list all tar files in all subdirs
302
+ filter_str='tar', # only grab files with this substring
303
+ # print debugging info -- note: info displayed likely to change at dev's whims
304
+ debug=False,
305
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
306
+ ):
307
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
308
+ urls = []
309
+ for name in names:
310
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
311
+ if s3_url_prefix is None:
312
+ contents_str = name
313
+ else:
314
+ # Construct the S3 path using the s3_url_prefix and the current name value
315
+ contents_str = posixpath.join(s3_url_prefix, name)
316
+ if debug:
317
+ print(f"get_all_s3_urls: {contents_str}:")
318
+ for subset in subsets:
319
+ subset_str = posixpath.join(contents_str, subset)
320
+ if debug:
321
+ print(f"subset_str = {subset_str}")
322
+ # Get the list of tar files in the current subset directory
323
+ profile = profiles.get(name, None)
324
+ tar_list = get_s3_contents(
325
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
326
+ for tar in tar_list:
327
+ # Escape spaces and parentheses in the tar filename for use in the shell command
328
+ tar = tar.replace(" ", "\\ ").replace(
329
+ "(", "\\(").replace(")", "\\)")
330
+ # Construct the S3 path to the current tar file
331
+ s3_path = posixpath.join(name, subset, tar) + " -"
332
+ # Construct the AWS CLI command to download the current tar file
333
+ if s3_url_prefix is None:
334
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
335
+ else:
336
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
337
+ if profiles.get(name):
338
+ request_str += f" --profile {profiles.get(name)}"
339
+ if debug:
340
+ print("request_str = ", request_str)
341
+ # Add the constructed URL to the list of URLs
342
+ urls.append(request_str)
343
+ return urls
344
+
345
+
346
+ def log_and_continue(exn):
347
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
348
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
349
+ return True
350
+
351
+
352
+ def is_valid_sample(sample):
353
+ has_json = "json" in sample
354
+ has_audio = "audio" in sample
355
+ is_silent = is_silence(sample["audio"])
356
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
357
+
358
+ return has_json and has_audio and not is_silent and not is_rejected
359
+
360
+ class S3DatasetConfig:
361
+ def __init__(
362
+ self,
363
+ id: str,
364
+ s3_path: str,
365
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
366
+ profile: Optional[str] = None,
367
+ ):
368
+ self.id = id
369
+ self.s3_path = s3_path
370
+ self.custom_metadata_fn = custom_metadata_fn
371
+ self.profile = profile
372
+ self.urls = []
373
+
374
+ def load_data_urls(self):
375
+ self.urls = get_all_s3_urls(
376
+ names=[self.s3_path],
377
+ s3_url_prefix=None,
378
+ recursive=True,
379
+ profiles={self.s3_path: self.profile} if self.profile else {},
380
+ )
381
+
382
+ return self.urls
383
+
384
+ def audio_decoder(key, value):
385
+ # Get file extension from key
386
+ ext = key.split(".")[-1]
387
+
388
+ if ext in AUDIO_KEYS:
389
+ return torchaudio.load(io.BytesIO(value))
390
+ else:
391
+ return None
392
+
393
+ def collation_fn(samples):
394
+ batched = list(zip(*samples))
395
+ result = []
396
+ for b in batched:
397
+ if isinstance(b[0], (int, float)):
398
+ b = np.array(b)
399
+ elif isinstance(b[0], torch.Tensor):
400
+ b = torch.stack(b)
401
+ elif isinstance(b[0], np.ndarray):
402
+ b = np.array(b)
403
+ else:
404
+ b = b
405
+ result.append(b)
406
+ return result
407
+
408
+ class S3WebDataLoader():
409
+ def __init__(
410
+ self,
411
+ datasets: List[S3DatasetConfig],
412
+ batch_size,
413
+ sample_size,
414
+ sample_rate=48000,
415
+ num_workers=8,
416
+ epoch_steps=1000,
417
+ random_crop=True,
418
+ force_channels="stereo",
419
+ augment_phase=True,
420
+ **data_loader_kwargs
421
+ ):
422
+
423
+ self.datasets = datasets
424
+
425
+ self.sample_size = sample_size
426
+ self.sample_rate = sample_rate
427
+ self.random_crop = random_crop
428
+ self.force_channels = force_channels
429
+ self.augment_phase = augment_phase
430
+
431
+ urls = [dataset.load_data_urls() for dataset in datasets]
432
+
433
+ # Flatten the list of lists of URLs
434
+ urls = [url for dataset_urls in urls for url in dataset_urls]
435
+
436
+ self.dataset = wds.DataPipeline(
437
+ wds.ResampledShards(urls),
438
+ wds.tarfile_to_samples(handler=log_and_continue),
439
+ wds.decode(audio_decoder, handler=log_and_continue),
440
+ wds.map(self.wds_preprocess, handler=log_and_continue),
441
+ wds.select(is_valid_sample),
442
+ wds.to_tuple("audio", "json", handler=log_and_continue),
443
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
444
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
445
+
446
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
447
+
448
+ def wds_preprocess(self, sample):
449
+
450
+ found_key, rewrite_key = '', ''
451
+ for k, v in sample.items(): # search the sample's entries for an audio key
452
+ for akey in AUDIO_KEYS:
453
+ if k.endswith(akey):
454
+ # to rename long/weird key with its simpler counterpart
455
+ found_key, rewrite_key = k, akey
456
+ break
457
+ if '' != found_key:
458
+ break
459
+ if '' == found_key: # got no audio!
460
+ return None # try returning None to tell WebDataset to skip this one
461
+
462
+ audio, in_sr = sample[found_key]
463
+ if in_sr != self.sample_rate:
464
+ resample_tf = T.Resample(in_sr, self.sample_rate)
465
+ audio = resample_tf(audio)
466
+
467
+ if self.sample_size is not None:
468
+ # Pad/crop and get the relative timestamp
469
+ pad_crop = PadCrop_Normalized_T(
470
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
471
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
472
+ audio)
473
+ sample["json"]["seconds_start"] = seconds_start
474
+ sample["json"]["seconds_total"] = seconds_total
475
+ sample["json"]["padding_mask"] = padding_mask
476
+ else:
477
+ t_start, t_end = 0, 1
478
+
479
+ # Check if audio is length zero, initialize to a single zero if so
480
+ if audio.shape[-1] == 0:
481
+ audio = torch.zeros(1, 1)
482
+
483
+ # Make the audio stereo and augment by randomly inverting phase
484
+ augs = torch.nn.Sequential(
485
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
486
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
487
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
488
+ )
489
+
490
+ audio = augs(audio)
491
+
492
+ sample["json"]["timestamps"] = (t_start, t_end)
493
+
494
+ if "text" in sample["json"]:
495
+ sample["json"]["prompt"] = sample["json"]["text"]
496
+
497
+ # Check for custom metadata functions
498
+ for dataset in self.datasets:
499
+ if dataset.custom_metadata_fn is None:
500
+ continue
501
+
502
+ if dataset.s3_path in sample["__url__"]:
503
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
504
+ sample["json"].update(custom_metadata)
505
+
506
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
507
+ del sample[found_key]
508
+
509
+ sample["audio"] = audio
510
+
511
+ # Add audio to the metadata as well for conditioning
512
+ sample["json"]["audio"] = audio
513
+
514
+ return sample
515
+
516
+ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
517
+
518
+ dataset_type = dataset_config.get("dataset_type", None)
519
+
520
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
521
+
522
+ if audio_channels == 1:
523
+ force_channels = "mono"
524
+ else:
525
+ force_channels = "stereo"
526
+
527
+ if dataset_type == "audio_dir":
528
+
529
+ audio_dir_configs = dataset_config.get("datasets", None)
530
+
531
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
532
+
533
+ training_dirs = []
534
+
535
+ custom_metadata_fn = None
536
+ custom_metadata_module_path = dataset_config.get("custom_metadata_module", None)
537
+
538
+ if custom_metadata_module_path is not None:
539
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
540
+ metadata_module = importlib.util.module_from_spec(spec)
541
+ spec.loader.exec_module(metadata_module)
542
+
543
+ custom_metadata_fn = metadata_module.get_custom_metadata
544
+
545
+ for audio_dir_config in audio_dir_configs:
546
+ audio_dir_path = audio_dir_config.get("path", None)
547
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
548
+ training_dirs.append(audio_dir_path)
549
+
550
+ train_set = SampleDataset(
551
+ training_dirs,
552
+ sample_rate=sample_rate,
553
+ sample_size=sample_size,
554
+ random_crop=dataset_config.get("random_crop", True),
555
+ force_channels=force_channels,
556
+ custom_metadata_fn=custom_metadata_fn,
557
+ relpath=training_dirs[0] #TODO: Make relpath relative to each training dir
558
+ )
559
+
560
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
561
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
562
+
563
+ elif dataset_type == "s3":
564
+ dataset_configs = []
565
+
566
+ for s3_config in dataset_config["datasets"]:
567
+
568
+ custom_metadata_fn = None
569
+ custom_metadata_module_path = s3_config.get("custom_metadata_module", None)
570
+
571
+ if custom_metadata_module_path is not None:
572
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
573
+ metadata_module = importlib.util.module_from_spec(spec)
574
+ spec.loader.exec_module(metadata_module)
575
+
576
+ custom_metadata_fn = metadata_module.get_custom_metadata
577
+
578
+ dataset_configs.append(
579
+ S3DatasetConfig(
580
+ id=s3_config["id"],
581
+ s3_path=s3_config["s3_path"],
582
+ custom_metadata_fn=custom_metadata_fn,
583
+ profile=s3_config.get("profile", None),
584
+ )
585
+ )
586
+
587
+ return S3WebDataLoader(
588
+ dataset_configs,
589
+ sample_rate=sample_rate,
590
+ sample_size=sample_size,
591
+ batch_size=batch_size,
592
+ random_crop=dataset_config.get("random_crop", True),
593
+ num_workers=num_workers,
594
+ persistent_workers=True,
595
+ force_channels=force_channels,
596
+ epoch_steps=dataset_config.get("epoch_steps", 2000),
597
+ ).data_loader
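Continuing from the dataset code above, a short sketch of what a batch from the audio_dir path looks like, assuming train_dl was built with create_dataloader_from_config as in the earlier dataset-config example: collation_fn stacks the per-sample audio tensors and leaves the per-sample info dicts as a tuple.

# train_dl: a DataLoader returned by create_dataloader_from_config (audio_dir case)
audio, infos = next(iter(train_dl))

print(audio.shape)                # (batch_size, channels, sample_size)
print(infos[0]["seconds_total"])  # per-sample metadata filled in by PadCrop_Normalized_T
print(infos[0].get("prompt"))     # whatever the custom metadata module returned, if any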
stable_audio_tools/data/utils.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ import random
3
+ import torch
4
+
5
+ from torch import nn
6
+ from typing import Tuple
7
+
8
+ class PadCrop(nn.Module):
9
+ def __init__(self, n_samples, randomize=True):
10
+ super().__init__()
11
+ self.n_samples = n_samples
12
+ self.randomize = randomize
13
+
14
+ def __call__(self, signal):
15
+ n, s = signal.shape
16
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
17
+ end = start + self.n_samples
18
+ output = signal.new_zeros([n, self.n_samples])
19
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
20
+ return output
21
+
22
+ class PadCrop_Normalized_T(nn.Module):
23
+
24
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
25
+
26
+ super().__init__()
27
+
28
+ self.n_samples = n_samples
29
+ self.sample_rate = sample_rate
30
+ self.randomize = randomize
31
+
32
+ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
33
+
34
+ n_channels, n_samples = source.shape
35
+
36
+ # If the audio is shorter than the desired length, pad it
37
+ upper_bound = max(0, n_samples - self.n_samples)
38
+
39
+ # If randomize is False, always start at the beginning of the audio
40
+ offset = 0
41
+ if(self.randomize and n_samples > self.n_samples):
42
+ offset = random.randint(0, upper_bound)
43
+
44
+ # Calculate the start and end times of the chunk
45
+ t_start = offset / (upper_bound + self.n_samples)
46
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
47
+
48
+ # Create the chunk
49
+ chunk = source.new_zeros([n_channels, self.n_samples])
50
+
51
+ # Copy the audio into the chunk
52
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
53
+
54
+ # Calculate the start and end times of the chunk in seconds
55
+ seconds_start = math.floor(offset / self.sample_rate)
56
+ seconds_total = math.ceil(n_samples / self.sample_rate)
57
+
58
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
59
+ padding_mask = torch.zeros([self.n_samples])
60
+ padding_mask[:min(n_samples, self.n_samples)] = 1
61
+
62
+
63
+ return (
64
+ chunk,
65
+ t_start,
66
+ t_end,
67
+ seconds_start,
68
+ seconds_total,
69
+ padding_mask
70
+ )
71
+
72
+ class PhaseFlipper(nn.Module):
73
+ "Randomly invert the phase of a signal"
74
+ def __init__(self, p=0.5):
75
+ super().__init__()
76
+ self.p = p
77
+ def __call__(self, signal):
78
+ return -signal if (random.random() < self.p) else signal
79
+
80
+ class Mono(nn.Module):
81
+ def __call__(self, signal):
82
+ return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
83
+
84
+ class Stereo(nn.Module):
85
+ def __call__(self, signal):
86
+ signal_shape = signal.shape
87
+ # Check if it's mono
88
+ if len(signal_shape) == 1: # s -> 2, s
89
+ signal = signal.unsqueeze(0).repeat(2, 1)
90
+ elif len(signal_shape) == 2:
91
+ if signal_shape[0] == 1: #1, s -> 2, s
92
+ signal = signal.repeat(2, 1)
93
+ elif signal_shape[0] > 2: #?, s -> 2,s
94
+ signal = signal[:2, :]
95
+
96
+ return signal
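A small sketch of what PadCrop_Normalized_T above returns, assuming a mono test signal shorter than the requested window; the specific numbers are illustrative only.

import torch

from stable_audio_tools.data.utils import PadCrop_Normalized_T

signal = torch.randn(1, 32000)  # 2 seconds of mono audio at 16 kHz
pad_crop = PadCrop_Normalized_T(n_samples=65536, sample_rate=16000, randomize=False)

chunk, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(signal)

print(chunk.shape)         # torch.Size([1, 65536]) -- zero-padded up to n_samples
print(seconds_start)       # 0, since randomize=False starts at the beginning
print(seconds_total)       # 2, the rounded-up source length in seconds
print(padding_mask.sum())  # tensor(32000.) -- ones where real audio is, zeros in the padding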
stable_audio_tools/inference/__init__.py ADDED
File without changes
stable_audio_tools/inference/generation.py ADDED
@@ -0,0 +1,243 @@
1
+ import numpy as np
2
+ import torch
3
+ import typing as tp
4
+ import math
5
+ from torchaudio import transforms as T
6
+
7
+ from .utils import prepare_audio
8
+ from .sampling import sample, sample_k
9
+ from ..data.utils import PadCrop
10
+
11
+ def generate_diffusion_uncond(
12
+ model,
13
+ steps: int = 250,
14
+ batch_size: int = 1,
15
+ sample_size: int = 2097152,
16
+ seed: int = -1,
17
+ device: str = "cuda",
18
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
+ init_noise_level: float = 1.0,
20
+ return_latents = False,
21
+ **sampler_kwargs
22
+ ) -> torch.Tensor:
23
+
24
+ # The length of the output in audio samples
25
+ audio_sample_size = sample_size
26
+
27
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
+ if model.pretransform is not None:
29
+ sample_size = sample_size // model.pretransform.downsampling_ratio
30
+
31
+ # Seed
32
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
+ print(seed)
35
+ torch.manual_seed(seed)
36
+ # Define the initial noise immediately after setting the seed
37
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
38
+
39
+ if init_audio is not None:
40
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
41
+ in_sr, init_audio = init_audio
42
+
43
+ io_channels = model.io_channels
44
+
45
+ # For latent models, set the io_channels to the autoencoder's io_channels
46
+ if model.pretransform is not None:
47
+ io_channels = model.pretransform.io_channels
48
+
49
+ # Prepare the initial audio for use by the model
50
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
51
+
52
+ # For latent models, encode the initial audio into latents
53
+ if model.pretransform is not None:
54
+ init_audio = model.pretransform.encode(init_audio)
55
+
56
+ init_audio = init_audio.repeat(batch_size, 1, 1)
57
+ else:
58
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
59
+ init_audio = None
60
+ init_noise_level = None
61
+
62
+ # Inpainting mask
63
+
64
+ if init_audio is not None:
65
+ # variations
66
+ sampler_kwargs["sigma_max"] = init_noise_level
67
+ mask = None
68
+ else:
69
+ mask = None
70
+
71
+ # Now the generative AI part:
72
+ # k-diffusion denoising process go!
73
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
74
+
75
+ # Denoising process done.
76
+ # If this is latent diffusion, decode latents back into audio
77
+ if model.pretransform is not None and not return_latents:
78
+ sampled = model.pretransform.decode(sampled)
79
+
80
+ # Return audio
81
+ return sampled
82
+
83
+
84
+ def generate_diffusion_cond(
85
+ model,
86
+ steps: int = 250,
87
+ cfg_scale=6,
88
+ conditioning: dict = None,
89
+ conditioning_tensors: tp.Optional[dict] = None,
90
+ negative_conditioning: dict = None,
91
+ negative_conditioning_tensors: tp.Optional[dict] = None,
92
+ batch_size: int = 1,
93
+ sample_size: int = 2097152,
94
+ sample_rate: int = 48000,
95
+ seed: int = -1,
96
+ device: str = "cuda",
97
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
98
+ init_noise_level: float = 1.0,
99
+ mask_args: dict = None,
100
+ return_latents = False,
101
+ **sampler_kwargs
102
+ ) -> torch.Tensor:
103
+ """
104
+ Generate audio from a prompt using a diffusion model.
105
+
106
+ Args:
107
+ model: The diffusion model to use for generation.
108
+ steps: The number of diffusion steps to use.
109
+ cfg_scale: Classifier-free guidance scale
110
+ conditioning: A dictionary of conditioning parameters to use for generation.
111
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
112
+ batch_size: The batch size to use for generation.
113
+ sample_size: The length of the audio to generate, in samples.
114
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
115
+ seed: The random seed to use for generation, or -1 to use a random seed.
116
+ device: The device to use for generation.
117
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
118
+ init_noise_level: The noise level to use when generating from an initial audio sample.
119
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
120
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
121
+ """
122
+
123
+ # The length of the output in audio samples
124
+ audio_sample_size = sample_size
125
+
126
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
127
+ if model.pretransform is not None:
128
+ sample_size = sample_size // model.pretransform.downsampling_ratio
129
+
130
+ # Seed
131
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
132
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1)
133
+ print(seed)
134
+ torch.manual_seed(seed)
135
+ # Define the initial noise immediately after setting the seed
136
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
137
+
138
+ # Conditioning
139
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
140
+ if conditioning_tensors is None:
141
+ conditioning_tensors = model.conditioner(conditioning, device)
142
+ conditioning_tensors = model.get_conditioning_inputs(conditioning_tensors)
143
+
144
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
145
+
146
+ if negative_conditioning_tensors is None:
147
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
148
+
149
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
150
+ else:
151
+ negative_conditioning_tensors = {}
152
+
153
+ if init_audio is not None:
154
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
155
+ in_sr, init_audio = init_audio
156
+
157
+ io_channels = model.io_channels
158
+
159
+ # For latent models, set the io_channels to the autoencoder's io_channels
160
+ if model.pretransform is not None:
161
+ io_channels = model.pretransform.io_channels
162
+
163
+ # Prepare the initial audio for use by the model
164
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
165
+
166
+ # For latent models, encode the initial audio into latents
167
+ if model.pretransform is not None:
168
+ init_audio = model.pretransform.encode(init_audio)
169
+
170
+ init_audio = init_audio.repeat(batch_size, 1, 1)
171
+ else:
172
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
173
+ init_audio = None
174
+ init_noise_level = None
175
+ mask_args = None
176
+
177
+ # Inpainting mask
178
+ if init_audio is not None and mask_args is not None:
179
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
180
+ # This is helpful for forward and reverse outpainting
181
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
182
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
183
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
184
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
185
+ croplen = pasteto - pastefrom
186
+ if cropfrom + croplen > sample_size:
187
+ croplen = sample_size - cropfrom
188
+ cropto = cropfrom + croplen
189
+ pasteto = pastefrom + croplen
190
+ cutpaste = init_audio.new_zeros(init_audio.shape)
191
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
192
+ #print(cropfrom, cropto, pastefrom, pasteto)
193
+ init_audio = cutpaste
194
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
195
+ mask = build_mask(sample_size, mask_args)
196
+ mask = mask.to(device)
197
+ elif init_audio is not None and mask_args is None:
198
+ # variations
199
+ sampler_kwargs["sigma_max"] = init_noise_level
200
+ mask = None
201
+ else:
202
+ mask = None
203
+
204
+ # Now the generative AI part:
205
+ # k-diffusion denoising process go!
206
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_tensors, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
207
+
208
+ # v-diffusion:
209
+ #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
210
+
211
+ # Denoising process done.
212
+ # If this is latent diffusion, decode latents back into audio
213
+ if model.pretransform is not None and not return_latents:
214
+ #cast sampled latents to pretransform dtype
215
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
216
+ sampled = model.pretransform.decode(sampled)
217
+
218
+ # Return audio
219
+ return sampled
220
+
221
+ # builds a softmask given the parameters
222
+ # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
223
+ # and anything between is a mixture of old/new
224
+ # ideally 0.5 is half/half mixture but i haven't figured this out yet
225
+ def build_mask(sample_size, mask_args):
226
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
227
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
228
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
229
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
230
+ marination = mask_args["marination"]
231
+ # use hann windows for softening the transition (i don't know if this is correct)
232
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
233
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
234
+ # build the mask.
235
+ mask = torch.zeros((sample_size))
236
+ mask[maskstart:maskend] = 1
237
+ mask[maskstart:maskstart+softnessL] = hannL
238
+ mask[maskend-softnessR:maskend] = hannR
239
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
240
+ if marination > 0:
241
+ mask = mask * (1-marination)
242
+ #print(mask)
243
+ return mask
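A hedged end-to-end sketch of calling generate_diffusion_cond above, with conditioning keys matching the stable_audio_2_0 config earlier in this commit (prompt, seconds_start, seconds_total). The model object is assumed to have been created and loaded elsewhere (for example via create_model_from_config, exported from the package __init__), and the conditioning is passed as a list with one metadata dict per batch item, which is assumed to be how the conditioner consumes it; treat the exact values as placeholders.

from stable_audio_tools.inference.generation import generate_diffusion_cond

# `model` is assumed to be a conditioned latent diffusion wrapper with weights loaded,
# moved to the GPU and set to eval mode.
conditioning = [{
    "prompt": "A beautiful piano arpeggio",
    "seconds_start": 0,
    "seconds_total": 30,
}]

audio = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    batch_size=1,
    sample_size=2097152,  # ~47.5 s at 44.1 kHz; divisible by the VAE's downsampling ratio of 2048
    device="cuda",
)

print(audio.shape)  # (batch_size, audio_channels, samples)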
stable_audio_tools/inference/sampling.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+ import math
3
+ from tqdm import trange
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+ @torch.no_grad()
24
+ def sample(model, x, steps, eta, **extra_args):
25
+ """Draws samples from a model given starting noise. v-diffusion"""
26
+ ts = x.new_ones([x.shape[0]])
27
+
28
+ # Create the noise schedule
29
+ t = torch.linspace(1, 0, steps + 1)[:-1]
30
+
31
+ alphas, sigmas = get_alphas_sigmas(t)
32
+
33
+ # The sampling loop
34
+ for i in trange(steps):
35
+
36
+ # Get the model output (v, the predicted velocity)
37
+ with torch.cuda.amp.autocast():
38
+ v = model(x, ts * t[i], **extra_args).float()
39
+
40
+ # Predict the noise and the denoised image
41
+ pred = x * alphas[i] - v * sigmas[i]
42
+ eps = x * sigmas[i] + v * alphas[i]
43
+
44
+ # If we are not on the last timestep, compute the noisy image for the
45
+ # next timestep.
46
+ if i < steps - 1:
47
+ # If eta > 0, adjust the scaling factor for the predicted noise
48
+ # downward according to the amount of additional noise to add
49
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
50
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
51
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
52
+
53
+ # Recombine the predicted noise and predicted denoised image in the
54
+ # correct proportions for the next step
55
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
56
+
57
+ # Add the correct amount of fresh noise
58
+ if eta:
59
+ x += torch.randn_like(x) * ddim_sigma
60
+
61
+ # If we are on the last timestep, output the denoised image
62
+ return pred
63
+
64
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
65
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
66
+ def get_bmask(i, steps, mask):
67
+ strength = (i+1)/(steps)
68
+ # convert to binary mask
69
+ bmask = torch.where(mask<=strength,1,0)
70
+ return bmask
71
+
72
+ def make_cond_model_fn(model, cond_fn):
73
+ def cond_model_fn(x, sigma, **kwargs):
74
+ with torch.enable_grad():
75
+ x = x.detach().requires_grad_()
76
+ denoised = model(x, sigma, **kwargs)
77
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
78
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
79
+ return cond_denoised
80
+ return cond_model_fn
81
+
82
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
83
+ # init_data is init_audio as latents (if this is latent diffusion)
84
+ # For sampling, set both init_data and mask to None
85
+ # For variations, set init_data
86
+ # For inpainting, set both init_data & mask
87
+ def sample_k(
88
+ model_fn,
89
+ noise,
90
+ init_data=None,
91
+ mask=None,
92
+ steps=100,
93
+ sampler_type="dpmpp-2m-sde",
94
+ sigma_min=0.5,
95
+ sigma_max=50,
96
+ rho=1.0, device="cuda",
97
+ callback=None,
98
+ cond_fn=None,
99
+ **extra_args
100
+ ):
101
+
102
+ denoiser = K.external.VDenoiser(model_fn)
103
+
104
+ if cond_fn is not None:
105
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
106
+
107
+ # Build the list of sigmas. Each sigma is a scalar controlling how much noise is present at that denoising step
108
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
109
+ # Scale the initial noise by sigma
110
+ noise = noise * sigmas[0]
111
+
112
+ wrapped_callback = callback
113
+
114
+ if mask is None and init_data is not None:
115
+ # VARIATION (no inpainting)
116
+ # set the initial latent to the init_data, and noise it with initial sigma
117
+ x = init_data + noise
118
+ elif mask is not None and init_data is not None:
119
+ # INPAINTING
120
+ bmask = get_bmask(0, steps, mask)
121
+ # initial noising
122
+ input_noised = init_data + noise
123
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
124
+ x = input_noised * bmask + noise * (1-bmask)
125
+ # define the inpainting callback function (Note: side effects, it mutates x)
126
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
127
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
128
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
129
+ def inpainting_callback(args):
130
+ i = args["i"]
131
+ x = args["x"]
132
+ sigma = args["sigma"]
133
+ #denoised = args["denoised"]
134
+ # noise the init_data input with this step's appropriate amount of noise
135
+ input_noised = init_data + torch.randn_like(init_data) * sigma
136
+ # shrinking hard mask
137
+ bmask = get_bmask(i, steps, mask)
138
+ # mix input_noise with x, using binary mask
139
+ new_x = input_noised * bmask + x * (1-bmask)
140
+ # mutate x
141
+ x[:,:,:] = new_x[:,:,:]
142
+ # wrap together the inpainting callback and the user-submitted callback.
143
+ if callback is None:
144
+ wrapped_callback = inpainting_callback
145
+ else:
146
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
147
+ else:
148
+ # SAMPLING
149
+ # set the initial latent to noise
150
+ x = noise
151
+
152
+
153
+ with torch.cuda.amp.autocast():
154
+ if sampler_type == "k-heun":
155
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
156
+ elif sampler_type == "k-lms":
157
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
158
+ elif sampler_type == "k-dpmpp-2s-ancestral":
159
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
160
+ elif sampler_type == "k-dpm-2":
161
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
162
+ elif sampler_type == "k-dpm-fast":
163
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
164
+ elif sampler_type == "k-dpm-adaptive":
165
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
166
+ elif sampler_type == "dpmpp-2m-sde":
167
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
168
+ elif sampler_type == "dpmpp-3m-sde":
169
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
170
+
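A minimal sketch of calling sample_k in the three modes described in the comment above (plain sampling, variations, inpainting). The zero-velocity stand-in model keeps the snippet self-contained; a real caller passes the diffusion model's forward function (e.g. model.model), and the shapes, step count and device are placeholders. device="cpu" lets the sketch run anywhere, though the autocast context inside sample_k only has an effect on CUDA.

import torch
from stable_audio_tools.inference.sampling import sample_k

def dummy_model_fn(x, t, **kwargs):
    # Placeholder v-prediction model: always predicts zero velocity
    return torch.zeros_like(x)

noise = torch.randn(1, 64, 256)  # (batch, latent channels, latent length), placeholder shape

# Plain sampling: both init_data and mask stay None
latents = sample_k(dummy_model_fn, noise, steps=20, device="cpu")

# Variations: pass init_data (typically together with a reduced sigma_max)
# latents = sample_k(dummy_model_fn, noise, init_data=init_latents, sigma_max=5.0, device="cpu")

# Inpainting: pass init_data plus a soft mask with the same latent length
# latents = sample_k(dummy_model_fn, noise, init_data=init_latents, mask=soft_mask, device="cpu")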
stable_audio_tools/inference/utils.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..data.utils import PadCrop
2
+
3
+ from torchaudio import transforms as T
4
+
5
+ def set_audio_channels(audio, target_channels):
6
+ if target_channels == 1:
7
+ # Convert to mono
8
+ audio = audio.mean(1, keepdim=True)
9
+ elif target_channels == 2:
10
+ # Convert to stereo
11
+ if audio.shape[1] == 1:
12
+ audio = audio.repeat(1, 2, 1)
13
+ elif audio.shape[1] > 2:
14
+ audio = audio[:, :2, :]
15
+ return audio
16
+
17
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
18
+
19
+ audio = audio.to(device)
20
+
21
+ if in_sr != target_sr:
22
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
23
+ audio = resample_tf(audio)
24
+
25
+ audio = PadCrop(target_length, randomize=False)(audio)
26
+
27
+ # Add batch dimension
28
+ if audio.dim() == 1:
29
+ audio = audio.unsqueeze(0).unsqueeze(0)
30
+ elif audio.dim() == 2:
31
+ audio = audio.unsqueeze(0)
32
+
33
+ audio = set_audio_channels(audio, target_channels)
34
+
35
+ return audio
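A minimal sketch of prepare_audio, assuming PadCrop (from the data utilities added in this commit) accepts a (channels, samples) tensor; the sample rates, lengths and channel layout below are placeholders:

import torch
from stable_audio_tools.inference.utils import prepare_audio

clip = torch.randn(1, 22050 * 2)  # placeholder mono clip: (channels, samples), 2 s at 22.05 kHz

batch = prepare_audio(
    clip,
    in_sr=22050,           # sample rate of the input clip
    target_sr=44100,       # resampled to the model's rate
    target_length=131072,  # padded/cropped to a fixed number of samples
    target_channels=2,     # mono is repeated to stereo
    device="cpu",
)

print(batch.shape)  # expected: torch.Size([1, 2, 131072])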
stable_audio_tools/interface/__init__.py ADDED
File without changes
stable_audio_tools/interface/gradio.py ADDED
@@ -0,0 +1,782 @@
1
+ import gc
2
+ import numpy as np
3
+ import gradio as gr
4
+ import json
5
+ import torch
6
+ import torchaudio
7
+
8
+ from aeiou.viz import audio_spectrogram_image
9
+ from einops import rearrange
10
+ from safetensors.torch import load_file
11
+ from torch.nn import functional as F
12
+ from torchaudio import transforms as T
13
+
14
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
15
+ from ..models.factory import create_model_from_config
16
+ from ..models.pretrained import get_pretrained_model
17
+ from ..models.utils import load_ckpt_state_dict
18
+ from ..inference.utils import prepare_audio
19
+ from ..training.utils import copy_state_dict
20
+
21
+ model = None
22
+ sample_rate = 44100
23
+ sample_size = 524288
24
+
25
+ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
26
+ global model, sample_rate, sample_size
27
+
28
+ if pretrained_name is not None:
29
+ print(f"Loading pretrained model {pretrained_name}")
30
+ model, model_config = get_pretrained_model(pretrained_name)
31
+
32
+ elif model_config is not None and model_ckpt_path is not None:
33
+ print(f"Creating model from config")
34
+ model = create_model_from_config(model_config)
35
+
36
+ print(f"Loading model checkpoint from {model_ckpt_path}")
37
+ # Load checkpoint
38
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
39
+ #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
40
+
41
+ sample_rate = model_config["sample_rate"]
42
+ sample_size = model_config["sample_size"]
43
+
44
+ if pretransform_ckpt_path is not None:
45
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
46
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
47
+ print(f"Done loading pretransform")
48
+
49
+ model.to(device).eval().requires_grad_(False)
50
+
51
+ if model_half:
52
+ model.to(torch.float16)
53
+
54
+ print(f"Done loading model")
55
+
56
+ return model, model_config
57
+
58
+ def generate_cond(
59
+ prompt,
60
+ negative_prompt=None,
61
+ seconds_start=0,
62
+ seconds_total=30,
63
+ latitude = 0.0,
64
+ longitude = 0.0,
65
+ temperature = 0.0,
66
+ humidity = 0.0,
67
+ wind_speed = 0.0,
68
+ pressure = 0.0,
69
+ minutes_of_day = 0.0,
70
+ day_of_year = 0.0,
71
+ cfg_scale=6.0,
72
+ steps=250,
73
+ preview_every=None,
74
+ seed=-1,
75
+ sampler_type="dpmpp-2m-sde",
76
+ sigma_min=0.03,
77
+ sigma_max=50,
78
+ cfg_rescale=0.4,
79
+ use_init=False,
80
+ init_audio=None,
81
+ init_noise_level=1.0,
82
+ mask_cropfrom=None,
83
+ mask_pastefrom=None,
84
+ mask_pasteto=None,
85
+ mask_maskstart=None,
86
+ mask_maskend=None,
87
+ mask_softnessL=None,
88
+ mask_softnessR=None,
89
+ mask_marination=None,
90
+ batch_size=1
91
+ ):
92
+
93
+ if torch.cuda.is_available():
94
+ torch.cuda.empty_cache()
95
+ gc.collect()
96
+
97
+ print(f"Prompt: {prompt}")
98
+
99
+ global preview_images
100
+ preview_images = []
101
+ if preview_every == 0:
102
+ preview_every = None
103
+
104
+ # Build the conditioning dict for each item in the batch (note: latitude is negated here)
105
+ conditioning = [{"prompt": prompt, "latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
106
+
107
+ if negative_prompt:
108
+ negative_conditioning = [{"prompt": negative_prompt, "latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total}] * batch_size
109
+ else:
110
+ negative_conditioning = None
111
+
112
+ #Get the device from the model
113
+ device = next(model.parameters()).device
114
+
115
+ seed = int(seed)
116
+
117
+ if not use_init:
118
+ init_audio = None
119
+
120
+ input_sample_size = sample_size
121
+
122
+ if init_audio is not None:
123
+ in_sr, init_audio = init_audio
124
+ # Turn into torch tensor, converting from int16 to float32
125
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
126
+
127
+ if init_audio.dim() == 1:
128
+ init_audio = init_audio.unsqueeze(0) # [1, n]
129
+ elif init_audio.dim() == 2:
130
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
131
+
132
+ if in_sr != sample_rate:
133
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
134
+ init_audio = resample_tf(init_audio)
135
+
136
+ audio_length = init_audio.shape[-1]
137
+
138
+ if audio_length > sample_size:
139
+
140
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
141
+
142
+ init_audio = (sample_rate, init_audio)
143
+
144
+ def progress_callback(callback_info):
145
+ global preview_images
146
+ denoised = callback_info["denoised"]
147
+ current_step = callback_info["i"]
148
+ sigma = callback_info["sigma"]
149
+
150
+ if (current_step - 1) % preview_every == 0:
151
+ if model.pretransform is not None:
152
+ denoised = model.pretransform.decode(denoised)
153
+ denoised = rearrange(denoised, "b d n -> d (b n)")
154
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
155
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
156
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
157
+
158
+ # If inpainting, send mask args
159
+ # This will definitely change in the future
160
+ if mask_cropfrom is not None:
161
+ mask_args = {
162
+ "cropfrom": mask_cropfrom,
163
+ "pastefrom": mask_pastefrom,
164
+ "pasteto": mask_pasteto,
165
+ "maskstart": mask_maskstart,
166
+ "maskend": mask_maskend,
167
+ "softnessL": mask_softnessL,
168
+ "softnessR": mask_softnessR,
169
+ "marination": mask_marination,
170
+ }
171
+ else:
172
+ mask_args = None
173
+
174
+ # Do the audio generation
175
+ audio = generate_diffusion_cond(
176
+ model,
177
+ conditioning=conditioning,
178
+ negative_conditioning=negative_conditioning,
179
+ steps=steps,
180
+ cfg_scale=cfg_scale,
181
+ batch_size=batch_size,
182
+ sample_size=input_sample_size,
183
+ sample_rate=sample_rate,
184
+ seed=seed,
185
+ device=device,
186
+ sampler_type=sampler_type,
187
+ sigma_min=sigma_min,
188
+ sigma_max=sigma_max,
189
+ init_audio=init_audio,
190
+ init_noise_level=init_noise_level,
191
+ mask_args = mask_args,
192
+ callback = progress_callback if preview_every is not None else None,
193
+ scale_phi = cfg_rescale
194
+ )
195
+
196
+ # Convert to WAV file
197
+ audio = rearrange(audio, "b d n -> d (b n)")
198
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
199
+ torchaudio.save("output.wav", audio, sample_rate)
200
+
201
+ # Let's look at a nice spectrogram too
202
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
203
+
204
+ return ("output.wav", [audio_spectrogram, *preview_images])
205
+
206
+ def generate_uncond(
207
+ steps=250,
208
+ seed=-1,
209
+ sampler_type="dpmpp-2m-sde",
210
+ sigma_min=0.03,
211
+ sigma_max=50,
212
+ use_init=False,
213
+ init_audio=None,
214
+ init_noise_level=1.0,
215
+ batch_size=1,
216
+ preview_every=None
217
+ ):
218
+
219
+ global preview_images
220
+
221
+ preview_images = []
222
+
223
+ if torch.cuda.is_available():
224
+ torch.cuda.empty_cache()
225
+ gc.collect()
226
+
227
+ #Get the device from the model
228
+ device = next(model.parameters()).device
229
+
230
+ seed = int(seed)
231
+
232
+ if not use_init:
233
+ init_audio = None
234
+
235
+ input_sample_size = sample_size
236
+
237
+ if init_audio is not None:
238
+ in_sr, init_audio = init_audio
239
+ # Turn into torch tensor, converting from int16 to float32
240
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
241
+
242
+ if init_audio.dim() == 1:
243
+ init_audio = init_audio.unsqueeze(0) # [1, n]
244
+ elif init_audio.dim() == 2:
245
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
246
+
247
+ if in_sr != sample_rate:
248
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
249
+ init_audio = resample_tf(init_audio)
250
+
251
+ audio_length = init_audio.shape[-1]
252
+
253
+ if audio_length > sample_size:
254
+
255
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
256
+
257
+ init_audio = (sample_rate, init_audio)
258
+
259
+ def progress_callback(callback_info):
260
+ global preview_images
261
+ denoised = callback_info["denoised"]
262
+ current_step = callback_info["i"]
263
+ sigma = callback_info["sigma"]
264
+
265
+ if (current_step - 1) % preview_every == 0:
266
+
267
+ if model.pretransform is not None:
268
+ denoised = model.pretransform.decode(denoised)
269
+
270
+ denoised = rearrange(denoised, "b d n -> d (b n)")
271
+
272
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
273
+
274
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
275
+
276
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
277
+
278
+ audio = generate_diffusion_uncond(
279
+ model,
280
+ steps=steps,
281
+ batch_size=batch_size,
282
+ sample_size=input_sample_size,
283
+ seed=seed,
284
+ device=device,
285
+ sampler_type=sampler_type,
286
+ sigma_min=sigma_min,
287
+ sigma_max=sigma_max,
288
+ init_audio=init_audio,
289
+ init_noise_level=init_noise_level,
290
+ callback = progress_callback if preview_every is not None else None
291
+ )
292
+
293
+ audio = rearrange(audio, "b d n -> d (b n)")
294
+
295
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
296
+
297
+ torchaudio.save("output.wav", audio, sample_rate)
298
+
299
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
300
+
301
+ return ("output.wav", [audio_spectrogram, *preview_images])
302
+
303
+ def generate_lm(
304
+ temperature=1.0,
305
+ top_p=0.95,
306
+ top_k=0,
307
+ batch_size=1,
308
+ ):
309
+
310
+ if torch.cuda.is_available():
311
+ torch.cuda.empty_cache()
312
+ gc.collect()
313
+
314
+ #Get the device from the model
315
+ device = next(model.parameters()).device
316
+
317
+ audio = model.generate_audio(
318
+ batch_size=batch_size,
319
+ max_gen_len = sample_size//model.pretransform.downsampling_ratio,
320
+ conditioning=None,
321
+ temp=temperature,
322
+ top_p=top_p,
323
+ top_k=top_k,
324
+ use_cache=True
325
+ )
326
+
327
+ audio = rearrange(audio, "b d n -> d (b n)")
328
+
329
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
330
+
331
+ torchaudio.save("output.wav", audio, sample_rate)
332
+
333
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
334
+
335
+ return ("output.wav", [audio_spectrogram])
336
+
337
+
338
+ def create_uncond_sampling_ui(model_config):
339
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
340
+
341
+ with gr.Row(equal_height=False):
342
+ with gr.Column():
343
+ with gr.Row():
344
+ # Steps slider
345
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
346
+
347
+ with gr.Accordion("Sampler params", open=False):
348
+
349
+ # Seed
350
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
351
+
352
+ # Sampler params
353
+ with gr.Row():
354
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
355
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
356
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
357
+
358
+ with gr.Accordion("Init audio", open=False):
359
+ init_audio_checkbox = gr.Checkbox(label="Use init audio")
360
+ init_audio_input = gr.Audio(label="Init audio")
361
+ init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
362
+
363
+ with gr.Column():
364
+ audio_output = gr.Audio(label="Output audio", interactive=False)
365
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
366
+ send_to_init_button = gr.Button("Send to init audio", scale=1)
367
+ send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
368
+
369
+ generate_button.click(fn=generate_uncond,
370
+ inputs=[
371
+ steps_slider,
372
+ seed_textbox,
373
+ sampler_type_dropdown,
374
+ sigma_min_slider,
375
+ sigma_max_slider,
376
+ init_audio_checkbox,
377
+ init_audio_input,
378
+ init_noise_level_slider,
379
+ ],
380
+ outputs=[
381
+ audio_output,
382
+ audio_spectrogram_output
383
+ ],
384
+ api_name="generate")
385
+ def create_conditioning_slider(min_val, max_val, label):
386
+ """
387
+ Create a Gradio slider for a given conditioning parameter.
388
+
389
+ Args:
390
+ - min_val: The minimum value for the slider.
391
+ - max_val: The maximum value for the slider.
392
+ - label: The label for the slider, which is displayed in the UI.
393
+
394
+ Returns:
395
+ - A gr.Slider object configured according to the provided parameters.
396
+ """
397
+ step = (max_val - min_val) / 1000
398
+ default_val = (max_val + min_val) / 2
399
+ print(f"Creating slider for {label} with min_val={min_val}, max_val={max_val}, step={step}, default_val={default_val}")
400
+ return gr.Slider(minimum=min_val, maximum=max_val, step=step, value=default_val, label=label)
401
+
402
+ def create_sampling_ui(model_config, inpainting=False):
403
+ with gr.Row():
404
+ with gr.Column(scale=6):
405
+ prompt = gr.Textbox(show_label=False, placeholder="Prompt")
406
+ negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
407
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
408
+
409
+ model_conditioning_config = model_config["model"].get("conditioning", None)
410
+
411
+ has_seconds_start = False
412
+ has_seconds_total = False
413
+
414
+ if model_conditioning_config is not None:
415
+ for conditioning_config in model_conditioning_config["configs"]:
416
+ if conditioning_config["id"] == "seconds_start":
417
+ has_seconds_start = True
418
+ if conditioning_config["id"] == "seconds_total":
419
+ has_seconds_total = True
420
+
421
+ with gr.Row(equal_height=False):
422
+ with gr.Column():
423
+ with gr.Row():
424
+
425
+ seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
426
+
427
+ seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
428
+
429
+ with gr.Row():
430
+ # Steps slider
431
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
432
+
433
+ # Preview Every slider
434
+ preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
435
+
436
+ # CFG scale
437
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=0.2, label="CFG scale")
438
+
439
+ with gr.Accordion("Climate and location", open=True):
440
+ latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
441
+ if latitude_config:
442
+ latitude_slider = create_conditioning_slider(
443
+ min_val=latitude_config["config"]["min_val"],
444
+ max_val=latitude_config["config"]["max_val"],
445
+ label="latitude")
446
+
447
+ longitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "longitude"), None)
448
+ if longitude_config:
449
+ longitude_slider = create_conditioning_slider(
450
+ min_val=longitude_config["config"]["min_val"],
451
+ max_val=longitude_config["config"]["max_val"],
452
+ label="longitude")
453
+
454
+ temperature_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "temperature"), None)
455
+ if temperature_config:
456
+ temperature_slider = create_conditioning_slider(
457
+ min_val=temperature_config["config"]["min_val"],
458
+ max_val=temperature_config["config"]["max_val"],
459
+ label="temperature")
460
+
461
+ humidity_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "humidity"), None)
462
+ if humidity_config:
463
+ humidity_slider = create_conditioning_slider(
464
+ min_val=humidity_config["config"]["min_val"],
465
+ max_val=humidity_config["config"]["max_val"],
466
+ label="humidity")
467
+
468
+ wind_speed_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "wind_speed"), None)
469
+ if wind_speed_config:
470
+ wind_speed_slider = create_conditioning_slider(
471
+ min_val=wind_speed_config["config"]["min_val"],
472
+ max_val=wind_speed_config["config"]["max_val"],
473
+ label="wind_speed")
474
+
475
+ pressure_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "pressure"), None)
476
+ if pressure_config:
477
+ pressure_slider = create_conditioning_slider(
478
+ min_val=pressure_config["config"]["min_val"],
479
+ max_val=pressure_config["config"]["max_val"],
480
+ label="pressure")
481
+
482
+ minutes_of_day_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "minutes_of_day"), None)
483
+ if minutes_of_day_config:
484
+ minutes_of_day_slider = create_conditioning_slider(
485
+ min_val=minutes_of_day_config["config"]["min_val"],
486
+ max_val=minutes_of_day_config["config"]["max_val"],
487
+ label="minutes_of_day")
488
+
489
+ day_of_year_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "day_of_year"), None)
490
+ if day_of_year_config:
491
+ day_of_year_slider = create_conditioning_slider(
492
+ min_val=day_of_year_config["config"]["min_val"],
493
+ max_val=day_of_year_config["config"]["max_val"],
494
+ label="Day of year")
495
+
496
+ with gr.Accordion("Sampler params", open=False):
497
+
498
+ # Seed
499
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
500
+
501
+ # Sampler params
502
+ with gr.Row():
503
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
504
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
505
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
506
+ cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.2, label="CFG rescale amount")
507
+
508
+ if inpainting:
509
+ # Inpainting Tab
510
+ with gr.Accordion("Inpainting", open=False):
511
+ sigma_max_slider.maximum=1000
512
+
513
+ init_audio_checkbox = gr.Checkbox(label="Do inpainting")
514
+ init_audio_input = gr.Audio(label="Init audio")
515
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this
516
+
517
+ mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %")
518
+ mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %")
519
+ mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %")
520
+
521
+ mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %")
522
+ mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %")
523
+ mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %")
524
+ mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %")
525
+ mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this
526
+
527
+ inputs = [prompt,
528
+ negative_prompt,
529
+ seconds_start_slider,
530
+ seconds_total_slider,
531
+ latitude_slider,
532
+ longitude_slider,
533
+ temperature_slider,
534
+ humidity_slider,
535
+ wind_speed_slider,
536
+ pressure_slider,
537
+ minutes_of_day_slider,
538
+ day_of_year_slider,
539
+ cfg_scale_slider,
540
+ steps_slider,
541
+ preview_every_slider,
542
+ seed_textbox,
543
+ sampler_type_dropdown,
544
+ sigma_min_slider,
545
+ sigma_max_slider,
546
+ cfg_rescale_slider,
547
+ init_audio_checkbox,
548
+ init_audio_input,
549
+ init_noise_level_slider,
550
+ mask_cropfrom_slider,
551
+ mask_pastefrom_slider,
552
+ mask_pasteto_slider,
553
+ mask_maskstart_slider,
554
+ mask_maskend_slider,
555
+ mask_softnessL_slider,
556
+ mask_softnessR_slider,
557
+ mask_marination_slider
558
+ ]
559
+ else:
560
+ # Default generation tab
561
+ with gr.Accordion("Init audio", open=False):
562
+ init_audio_checkbox = gr.Checkbox(label="Use init audio")
563
+ init_audio_input = gr.Audio(label="Init audio")
564
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
565
+
566
+ inputs = [prompt,
567
+ negative_prompt,
568
+ seconds_start_slider,
569
+ seconds_total_slider,
570
+ latitude_slider,
571
+ longitude_slider,
572
+ temperature_slider,
573
+ humidity_slider,
574
+ wind_speed_slider,
575
+ pressure_slider,
576
+ minutes_of_day_slider,
577
+ day_of_year_slider,
578
+ cfg_scale_slider,
579
+ steps_slider,
580
+ preview_every_slider,
581
+ seed_textbox,
582
+ sampler_type_dropdown,
583
+ sigma_min_slider,
584
+ sigma_max_slider,
585
+ cfg_rescale_slider,
586
+ init_audio_checkbox,
587
+ init_audio_input,
588
+ init_noise_level_slider
589
+ ]
590
+
591
+ with gr.Column():
592
+ audio_output = gr.Audio(label="Output audio", interactive=False)
593
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
594
+ send_to_init_button = gr.Button("Send to init audio", scale=1)
595
+ send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
596
+
597
+ generate_button.click(fn=generate_cond,
598
+ inputs=inputs,
599
+ outputs=[
600
+ audio_output,
601
+ audio_spectrogram_output
602
+ ],
603
+ api_name="generate")
604
+
605
+
606
+ def create_txt2audio_ui(model_config):
607
+ with gr.Blocks() as ui:
608
+ with gr.Tab("Generation"):
609
+ create_sampling_ui(model_config)
610
+ with gr.Tab("Inpainting"):
611
+ create_sampling_ui(model_config, inpainting=True)
612
+ return ui
613
+
614
+ def create_diffusion_uncond_ui(model_config):
615
+ with gr.Blocks() as ui:
616
+ create_uncond_sampling_ui(model_config)
617
+
618
+ return ui
619
+
620
+ def autoencoder_process(audio, latent_noise, n_quantizers):
621
+ if torch.cuda.is_available():
622
+ torch.cuda.empty_cache()
623
+ gc.collect()
624
+
625
+ #Get the device from the model
626
+ device = next(model.parameters()).device
627
+
628
+ in_sr, audio = audio
629
+
630
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
631
+
632
+ if audio.dim() == 1:
633
+ audio = audio.unsqueeze(0)
634
+ else:
635
+ audio = audio.transpose(0, 1)
636
+
637
+ audio = model.preprocess_audio_for_encoder(audio, in_sr)
638
+ # Note: If you need to do chunked encoding, to reduce VRAM,
639
+ # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
640
+ # To turn it off, do chunked=False
641
+ # Optimal overlap and chunk_size values will depend on the model.
642
+ # See encode_audio & decode_audio in autoencoders.py for more info
643
+ # Get dtype of model
644
+ dtype = next(model.parameters()).dtype
645
+
646
+ audio = audio.to(dtype)
647
+
648
+ if n_quantizers > 0:
649
+ latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
650
+ else:
651
+ latents = model.encode_audio(audio, chunked=False)
652
+
653
+ if latent_noise > 0:
654
+ latents = latents + torch.randn_like(latents) * latent_noise
655
+
656
+ audio = model.decode_audio(latents, chunked=False)
657
+
658
+ audio = rearrange(audio, "b d n -> d (b n)")
659
+
660
+ audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
661
+
662
+ torchaudio.save("output.wav", audio, sample_rate)
663
+
664
+ return "output.wav"
665
+
666
+ def create_autoencoder_ui(model_config):
667
+
668
+ is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"]
669
+
670
+ if is_dac_rvq:
671
+ n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"]
672
+ else:
673
+ n_quantizers = 0
674
+
675
+ with gr.Blocks() as ui:
676
+ input_audio = gr.Audio(label="Input audio")
677
+ output_audio = gr.Audio(label="Output audio", interactive=False)
678
+ n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq)
679
+ latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise")
680
+ process_button = gr.Button("Process", variant='primary', scale=1)
681
+ process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process")
682
+
683
+ return ui
684
+
685
+ def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
686
+
687
+ if torch.cuda.is_available():
688
+ torch.cuda.empty_cache()
689
+ gc.collect()
690
+
691
+ #Get the device from the model
692
+ device = next(model.parameters()).device
693
+
694
+ in_sr, audio = audio
695
+
696
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
697
+
698
+ if audio.dim() == 1:
699
+ audio = audio.unsqueeze(0) # [1, n]
700
+ elif audio.dim() == 2:
701
+ audio = audio.transpose(0, 1) # [n, 2] -> [2, n]
702
+
703
+ audio = audio.unsqueeze(0)
704
+
705
+ audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max})
706
+
707
+ audio = rearrange(audio, "b d n -> d (b n)")
708
+
709
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
710
+
711
+ torchaudio.save("output.wav", audio, sample_rate)
712
+
713
+ return "output.wav"
714
+
715
+ def create_diffusion_prior_ui(model_config):
716
+ with gr.Blocks() as ui:
717
+ input_audio = gr.Audio(label="Input audio")
718
+ output_audio = gr.Audio(label="Output audio", interactive=False)
719
+ # Sampler params
720
+ with gr.Row():
721
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
722
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
723
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
724
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
725
+ process_button = gr.Button("Process", variant='primary', scale=1)
726
+ process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process")
727
+
728
+ return ui
729
+
730
+ def create_lm_ui(model_config):
731
+ with gr.Blocks() as ui:
732
+ output_audio = gr.Audio(label="Output audio", interactive=False)
733
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
734
+
735
+ # Sampling params
736
+ with gr.Row():
737
+ temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature")
738
+ top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p")
739
+ top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k")
740
+
741
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
742
+ generate_button.click(
743
+ fn=generate_lm,
744
+ inputs=[
745
+ temperature_slider,
746
+ top_p_slider,
747
+ top_k_slider
748
+ ],
749
+ outputs=[output_audio, audio_spectrogram_output],
750
+ api_name="generate"
751
+ )
752
+
753
+ return ui
754
+
755
+ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
756
+
757
+ assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
758
+
759
+ if model_config_path is not None:
760
+ # Load config from json file
761
+ with open(model_config_path) as f:
762
+ model_config = json.load(f)
763
+ else:
764
+ model_config = None
765
+
766
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
767
+ _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
768
+
769
+ model_type = model_config["model_type"]
770
+
771
+ if model_type == "diffusion_cond":
772
+ ui = create_txt2audio_ui(model_config)
773
+ elif model_type == "diffusion_uncond":
774
+ ui = create_diffusion_uncond_ui(model_config)
775
+ elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
776
+ ui = create_autoencoder_ui(model_config)
777
+ elif model_type == "diffusion_prior":
778
+ ui = create_diffusion_prior_ui(model_config)
779
+ elif model_type == "lm":
780
+ ui = create_lm_ui(model_config)
781
+
782
+ return ui
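A minimal sketch of wiring the interface above into a launch script; the config and checkpoint paths are placeholders for files you supply yourself:

from stable_audio_tools.interface.gradio import create_ui

# create_ui loads the config, builds the model, and picks the tab layout
# from model_config["model_type"]
ui = create_ui(
    model_config_path="model_config.json",  # placeholder path
    ckpt_path="model.ckpt",                 # placeholder path
    model_half=False,
)
ui.queue()              # queue requests, since generation can take a while
ui.launch(share=False)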
stable_audio_tools/interface/testing.py ADDED
@@ -0,0 +1,409 @@
1
+ import gc
2
+ import numpy as np
3
+ import json
4
+ import torch
5
+ import torchaudio
6
+ import os
7
+ import re
8
+
9
+ from aeiou.viz import audio_spectrogram_image
10
+ from einops import rearrange
11
+ from safetensors.torch import load_file
12
+ from torch.nn import functional as F
13
+ from torchaudio import transforms as T
14
+
15
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
16
+ from ..models.factory import create_model_from_config
17
+ from ..models.pretrained import get_pretrained_model
18
+ from ..models.utils import load_ckpt_state_dict
19
+ from ..inference.utils import prepare_audio
20
+ from ..training.utils import copy_state_dict
21
+
22
+
23
+ model = None
24
+ sample_rate = 44100
25
+ sample_size = 524288
26
+
27
+ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
28
+ global model, sample_rate, sample_size
29
+
30
+ if pretrained_name is not None:
31
+ print(f"Loading pretrained model {pretrained_name}")
32
+ model, model_config = get_pretrained_model(pretrained_name)
33
+
34
+ elif model_config is not None and model_ckpt_path is not None:
35
+ print(f"Creating model from config")
36
+ model = create_model_from_config(model_config)
37
+
38
+ print(f"Loading model checkpoint from {model_ckpt_path}")
39
+ # Load checkpoint
40
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
41
+ #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
42
+
43
+ sample_rate = model_config["sample_rate"]
44
+ sample_size = model_config["sample_size"]
45
+
46
+ if pretransform_ckpt_path is not None:
47
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
48
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
49
+ print(f"Done loading pretransform")
50
+
51
+ model.to(device).eval().requires_grad_(False)
52
+
53
+ if model_half:
54
+ model.to(torch.float16)
55
+
56
+ print(f"Done loading model")
57
+
58
+ return model, model_config
59
+
60
+ def generate_cond_with_path(
61
+ prompt,
62
+ negative_prompt=None,
63
+ seconds_start=0,
64
+ seconds_total=30,
65
+ latitude = 0.0,
66
+ longitude = 0.0,
67
+ temperature = 0.0,
68
+ humidity = 0.0,
69
+ wind_speed = 0.0,
70
+ pressure = 0.0,
71
+ minutes_of_day = 0.0,
72
+ day_of_year = 0.0,
73
+ cfg_scale=6.0,
74
+ steps=250,
75
+ preview_every=None,
76
+ seed=-1,
77
+ sampler_type="dpmpp-2m-sde",
78
+ sigma_min=0.03,
79
+ sigma_max=50,
80
+ cfg_rescale=0.4,
81
+ use_init=False,
82
+ init_audio=None,
83
+ init_noise_level=1.0,
84
+ mask_cropfrom=None,
85
+ mask_pastefrom=None,
86
+ mask_pasteto=None,
87
+ mask_maskstart=None,
88
+ mask_maskend=None,
89
+ mask_softnessL=None,
90
+ mask_softnessR=None,
91
+ mask_marination=None,
92
+ batch_size=1,
93
+ destination_folder=None,
94
+ file_name=None
95
+ ):
96
+
97
+ if torch.cuda.is_available():
98
+ torch.cuda.empty_cache()
99
+ gc.collect()
100
+
101
+ print(f"Prompt: {prompt}")
102
+
103
+ global preview_images
104
+ preview_images = []
105
+ if preview_every == 0:
106
+ preview_every = None
107
+
108
+ # Build the conditioning dict for each item in the batch
109
+ conditioning = [{"prompt": prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
110
+
111
+ if negative_prompt:
112
+ negative_conditioning = [{"prompt": negative_prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total}] * batch_size
113
+ else:
114
+ negative_conditioning = None
115
+
116
+ #Get the device from the model
117
+ device = next(model.parameters()).device
118
+
119
+ seed = int(seed)
120
+
121
+ if not use_init:
122
+ init_audio = None
123
+
124
+ input_sample_size = sample_size
125
+
126
+ if init_audio is not None:
127
+ in_sr, init_audio = init_audio
128
+ # Turn into torch tensor, converting from int16 to float32
129
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
130
+
131
+ if init_audio.dim() == 1:
132
+ init_audio = init_audio.unsqueeze(0) # [1, n]
133
+ elif init_audio.dim() == 2:
134
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
135
+
136
+ if in_sr != sample_rate:
137
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
138
+ init_audio = resample_tf(init_audio)
139
+
140
+ audio_length = init_audio.shape[-1]
141
+
142
+ if audio_length > sample_size:
143
+
144
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
145
+
146
+ init_audio = (sample_rate, init_audio)
147
+
148
+ def progress_callback(callback_info):
149
+ global preview_images
150
+ denoised = callback_info["denoised"]
151
+ current_step = callback_info["i"]
152
+ sigma = callback_info["sigma"]
153
+
154
+ if (current_step - 1) % preview_every == 0:
155
+ if model.pretransform is not None:
156
+ denoised = model.pretransform.decode(denoised)
157
+ denoised = rearrange(denoised, "b d n -> d (b n)")
158
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
159
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
160
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
161
+
162
+ # If inpainting, send mask args
163
+ # This will definitely change in the future
164
+ if mask_cropfrom is not None:
165
+ mask_args = {
166
+ "cropfrom": mask_cropfrom,
167
+ "pastefrom": mask_pastefrom,
168
+ "pasteto": mask_pasteto,
169
+ "maskstart": mask_maskstart,
170
+ "maskend": mask_maskend,
171
+ "softnessL": mask_softnessL,
172
+ "softnessR": mask_softnessR,
173
+ "marination": mask_marination,
174
+ }
175
+ else:
176
+ mask_args = None
177
+
178
+ # Do the audio generation
179
+ audio = generate_diffusion_cond(
180
+ model,
181
+ conditioning=conditioning,
182
+ negative_conditioning=negative_conditioning,
183
+ steps=steps,
184
+ cfg_scale=cfg_scale,
185
+ batch_size=batch_size,
186
+ sample_size=input_sample_size,
187
+ sample_rate=sample_rate,
188
+ seed=seed,
189
+ device=device,
190
+ sampler_type=sampler_type,
191
+ sigma_min=sigma_min,
192
+ sigma_max=sigma_max,
193
+ init_audio=init_audio,
194
+ init_noise_level=init_noise_level,
195
+ mask_args = mask_args,
196
+ callback = progress_callback if preview_every is not None else None,
197
+ scale_phi = cfg_rescale
198
+ )
199
+
200
+ # Convert to WAV file
201
+ audio = rearrange(audio, "b d n -> d (b n)")
202
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
203
+ # Save to the destination folder with the requested file name, adding the .wav extension
204
+
205
+ if destination_folder is not None and file_name is not None:
206
+ torchaudio.save(f"{destination_folder}/{file_name}.wav", audio, sample_rate)
207
+
208
+
209
+
210
+ # Let's look at a nice spectrogram too
211
+ # audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
212
+
213
+ # return ("output.wav", [audio_spectrogram, *preview_images])
214
+
215
+
216
+
217
+ def generate_lm(
218
+ temperature=1.0,
219
+ top_p=0.95,
220
+ top_k=0,
221
+ batch_size=1,
222
+ ):
223
+
224
+ if torch.cuda.is_available():
225
+ torch.cuda.empty_cache()
226
+ gc.collect()
227
+
228
+ #Get the device from the model
229
+ device = next(model.parameters()).device
230
+
231
+ audio = model.generate_audio(
232
+ batch_size=batch_size,
233
+ max_gen_len = sample_size//model.pretransform.downsampling_ratio,
234
+ conditioning=None,
235
+ temp=temperature,
236
+ top_p=top_p,
237
+ top_k=top_k,
238
+ use_cache=True
239
+ )
240
+
241
+ audio = rearrange(audio, "b d n -> d (b n)")
242
+
243
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
244
+
245
+ torchaudio.save("output.wav", audio, sample_rate)
246
+
247
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
248
+
249
+ return ("output.wav", [audio_spectrogram])
250
+
251
+
252
+
253
+
254
+ def autoencoder_process(audio, latent_noise, n_quantizers):
255
+ if torch.cuda.is_available():
256
+ torch.cuda.empty_cache()
257
+ gc.collect()
258
+
259
+ #Get the device from the model
260
+ device = next(model.parameters()).device
261
+
262
+ in_sr, audio = audio
263
+
264
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
265
+
266
+ if audio.dim() == 1:
267
+ audio = audio.unsqueeze(0)
268
+ else:
269
+ audio = audio.transpose(0, 1)
270
+
271
+ audio = model.preprocess_audio_for_encoder(audio, in_sr)
272
+ # Note: If you need to do chunked encoding, to reduce VRAM,
273
+ # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
274
+ # To turn it off, do chunked=False
275
+ # Optimal overlap and chunk_size values will depend on the model.
276
+ # See encode_audio & decode_audio in autoencoders.py for more info
277
+ # Get dtype of model
278
+ dtype = next(model.parameters()).dtype
279
+
280
+ audio = audio.to(dtype)
281
+
282
+ if n_quantizers > 0:
283
+ latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
284
+ else:
285
+ latents = model.encode_audio(audio, chunked=False)
286
+
287
+ if latent_noise > 0:
288
+ latents = latents + torch.randn_like(latents) * latent_noise
289
+
290
+ audio = model.decode_audio(latents, chunked=False)
291
+
292
+ audio = rearrange(audio, "b d n -> d (b n)")
293
+
294
+ audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
295
+
296
+ torchaudio.save("output.wav", audio, sample_rate)
297
+
298
+ return "output.wav"
299
+
300
+ def load_and_generate(model_path, json_dir, output_dir):
301
+ """Load JSON files and generate audio for each set of conditions."""
302
+ # List all files in the json_dir
303
+ files = os.listdir(json_dir)
304
+
305
+ # Filter for JSON files
306
+ json_files = [file for file in files if file.endswith('.json')]
307
+
308
+ if not json_files:
309
+ print(f"No JSON files found in {json_dir}. Please check the directory path and file permissions.")
310
+ return
311
+
312
+ for json_filename in json_files:
313
+ json_file_path = os.path.join(json_dir, json_filename)
314
+
315
+ try:
316
+ with open(json_file_path, 'r') as file:
317
+ data = json.load(file)
318
+ except Exception as e:
319
+ print(f"Failed to read or parse {json_file_path}: {e}")
320
+ continue
321
+
322
+ # Print the JSON path
323
+ print(json_file_path)
324
+
325
+ # Extract conditions from JSON
326
+ conditions = {
327
+ 'birdSpecies': data['birdSpecies'],
328
+ 'latitude': data['coord']['lat'],
329
+ 'longitude': data['coord']['lon'],
330
+ 'temperature': data['main']['temp'],
331
+ 'humidity': data['main']['humidity'],
332
+ 'pressure': data['main']['pressure'],
333
+ 'wind_speed': data['wind']['speed'],
334
+ 'day_of_year': data['dayOfYear'],
335
+ 'minutes_of_day': data['minutesOfDay']
336
+ }
337
+
338
+ # Extract base filename components
339
+ step_number = re.search(r'step=(\d+)', model_path).group(1)
340
+ bird_species = conditions['birdSpecies'].replace(' ', '_')
341
+ base_filename = f"{bird_species}_{os.path.splitext(json_filename)[0]}_{step_number}_cfg_scale_"
342
+
343
+
344
+
345
+ #An array of cfg scale values to test
346
+ cfg_scales = [1.8, 2.5, 4.0, 5.0, 12.0]
347
+
348
+ # Generate audio we do this 4 times with a loop
349
+ for scale in cfg_scales:
350
+ generate_cond_with_path(prompt = "",
351
+ negative_prompt="",
352
+ seconds_start=0,
353
+ seconds_total=22,
354
+ latitude = conditions['latitude'],
355
+ longitude = conditions['longitude'],
356
+ temperature = conditions['temperature'],
357
+ humidity = conditions['humidity'],
358
+ wind_speed = conditions['wind_speed'],
359
+ pressure = conditions['pressure'],
360
+ minutes_of_day = conditions['minutes_of_day'],
361
+ day_of_year = conditions['day_of_year'],
362
+ cfg_scale=scale,
363
+ steps=250,
364
+ preview_every=None,
365
+ seed=-1,
366
+ sampler_type="dpmpp-2m-sde",
367
+ sigma_min=0.03,
368
+ sigma_max=50,
369
+ cfg_rescale=0.4,
370
+ use_init=False,
371
+ init_audio=None,
372
+ init_noise_level=1.0,
373
+ mask_cropfrom=None,
374
+ mask_pastefrom=None,
375
+ mask_pasteto=None,
376
+ mask_maskstart=None,
377
+ mask_maskend=None,
378
+ mask_softnessL=None,
379
+ mask_softnessR=None,
380
+ mask_marination=None,
381
+ batch_size=1,
382
+ destination_folder=output_dir,
383
+ file_name=base_filename + str(scale))
384
+
385
+
386
+ def runTests(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False, json_dir=None, output_dir=None):
387
+ assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
388
+
389
+ if model_config_path is not None:
390
+ # Load config from json file
391
+ with open(model_config_path) as f:
392
+ model_config = json.load(f)
393
+ else:
394
+ model_config = None
395
+
396
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
397
+ _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
398
+
399
+ # Ensure the output directory exists before generating, e.g. os.makedirs(output_dir, exist_ok=True)
400
+
401
+ # Process all JSON files and generate audio
402
+ load_and_generate(ckpt_path, json_dir, output_dir)
403
+
404
+
405
+
406
+
407
+
408
+
409
+
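A minimal sketch of driving the batch test harness above; every path is a placeholder, the checkpoint filename must contain "step=<number>" because load_and_generate extracts it with a regex, and each JSON file must provide the birdSpecies, coord, main, wind, dayOfYear and minutesOfDay fields read in load_and_generate:

from stable_audio_tools.interface.testing import runTests

runTests(
    model_config_path="model_config.json",           # placeholder config path
    ckpt_path="checkpoints/model-step=100000.ckpt",  # placeholder; name must contain "step=<N>"
    json_dir="test_conditions",                      # directory of *.json condition files
    output_dir="generated_audio",                    # existing directory where the .wav files are written
)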
stable_audio_tools/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_model_from_config, create_model_from_config_path
stable_audio_tools/models/adp.py ADDED
@@ -0,0 +1,1588 @@
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, pi, log2
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+ from packaging import version
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from torch import Tensor, einsum
16
+ from torch.backends.cuda import sdp_kernel
17
+ from torch.nn import functional as F
18
+ from dac.nn.layers import Snake1d
19
+
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+ T = TypeVar("T")
36
+
37
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val: Optional[T]) -> bool:
43
+ return val is not None
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
53
+ for key in d.keys():
54
+ no_prefix = int(not key.startswith(prefix))
55
+ return_dicts[no_prefix][key] = d[key]
56
+ return return_dicts
57
+
58
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
+ if keep_prefix:
61
+ return kwargs_with_prefix, kwargs
62
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
+ return kwargs_no_prefix, kwargs
64
+
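For orientation, a small sketch of what groupby does with prefixed kwargs; the dictionary contents here are invented for illustration:

# Splits kwargs into (prefix-stripped matches, everything else).
attn_kwargs, rest = groupby("attention_", {"attention_heads": 8, "attention_features": 64, "channels": 128})
# attn_kwargs == {"heads": 8, "features": 64}
# rest == {"channels": 128}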
65
+ """
66
+ Convolutional Blocks
67
+ """
68
+ import typing as tp
69
+
70
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
+ # License available in LICENSES/LICENSE_META.txt
72
+
73
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
+ padding_total: int = 0) -> int:
75
+ """See `pad_for_conv1d`."""
76
+ length = x.shape[-1]
77
+ n_frames = (length - kernel_size + padding_total) / stride + 1
78
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
+ return ideal_length - length
80
+
81
+
82
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
+ """Pad for a convolution to make sure that the last window is full.
84
+ Extra padding is added at the end. This is required to ensure that we can rebuild
85
+ an output of the same length, as otherwise, even with padding, some time steps
86
+ might get removed.
87
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
88
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
90
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
+ 1 2 3 4 # once you removed padding, we are missing one time step !
92
+ """
93
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
+ return F.pad(x, (0, extra_padding))
95
+
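A worked instance of the padding arithmetic above, assuming torch and the two helpers are in scope; the tensor sizes are arbitrary:

# length=5, kernel_size=4, stride=2, padding_total=4:
#   n_frames      = (5 - 4 + 4) / 2 + 1 = 3.5
#   ideal_length  = (ceil(3.5) - 1) * 2 + (4 - 4) = 6
#   extra_padding = 6 - 5 = 1, so one zero is appended to make the last window full
x = torch.zeros(1, 1, 5)
print(get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4))   # 1
print(pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=4).shape)           # torch.Size([1, 1, 6])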
96
+
97
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
100
+ """
101
+ length = x.shape[-1]
102
+ padding_left, padding_right = paddings
103
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
+ if mode == 'reflect':
105
+ max_pad = max(padding_left, padding_right)
106
+ extra_pad = 0
107
+ if length <= max_pad:
108
+ extra_pad = max_pad - length + 1
109
+ x = F.pad(x, (0, extra_pad))
110
+ padded = F.pad(x, paddings, mode, value)
111
+ end = padded.shape[-1] - extra_pad
112
+ return padded[..., :end]
113
+ else:
114
+ return F.pad(x, paddings, mode, value)
115
+
116
+
117
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
+ padding_left, padding_right = paddings
120
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
+ assert (padding_left + padding_right) <= x.shape[-1]
122
+ end = x.shape[-1] - padding_right
123
+ return x[..., padding_left: end]
124
+
125
+
126
+ class Conv1d(nn.Conv1d):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+
130
+ def forward(self, x: Tensor, causal=False) -> Tensor:
131
+ kernel_size = self.kernel_size[0]
132
+ stride = self.stride[0]
133
+ dilation = self.dilation[0]
134
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
+ padding_total = kernel_size - stride
136
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
+ if causal:
138
+ # Left padding for causal
139
+ x = pad1d(x, (padding_total, extra_padding))
140
+ else:
141
+ # Asymmetric padding required for odd strides
142
+ padding_right = padding_total // 2
143
+ padding_left = padding_total - padding_right
144
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
145
+ return super().forward(x)
146
+
147
+ class ConvTranspose1d(nn.ConvTranspose1d):
148
+ def __init__(self, *args, **kwargs):
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def forward(self, x: Tensor, causal=False) -> Tensor:
152
+ kernel_size = self.kernel_size[0]
153
+ stride = self.stride[0]
154
+ padding_total = kernel_size - stride
155
+
156
+ y = super().forward(x)
157
+
158
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
+ # removed at the very end, when keeping only the right length for the output,
160
+ # as removing it here would require also passing the length at the matching layer
161
+ # in the encoder.
162
+ if causal:
163
+ padding_right = ceil(padding_total)
164
+ padding_left = padding_total - padding_right
165
+ y = unpad1d(y, (padding_left, padding_right))
166
+ else:
167
+ # Asymmetric padding required for odd strides
168
+ padding_right = padding_total // 2
169
+ padding_left = padding_total - padding_right
170
+ y = unpad1d(y, (padding_left, padding_right))
171
+ return y
172
+
173
+
174
+ def Downsample1d(
175
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
+ ) -> nn.Module:
177
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
+
179
+ return Conv1d(
180
+ in_channels=in_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=factor * kernel_multiplier + 1,
183
+ stride=factor
184
+ )
185
+
186
+
187
+ def Upsample1d(
188
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
+ ) -> nn.Module:
190
+
191
+ if factor == 1:
192
+ return Conv1d(
193
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
+ )
195
+
196
+ if use_nearest:
197
+ return nn.Sequential(
198
+ nn.Upsample(scale_factor=factor, mode="nearest"),
199
+ Conv1d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=3
203
+ ),
204
+ )
205
+ else:
206
+ return ConvTranspose1d(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=factor * 2,
210
+ stride=factor
211
+ )
212
+
213
+
214
+ class ConvBlock1d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ *,
220
+ kernel_size: int = 3,
221
+ stride: int = 1,
222
+ dilation: int = 1,
223
+ num_groups: int = 8,
224
+ use_norm: bool = True,
225
+ use_snake: bool = False
226
+ ) -> None:
227
+ super().__init__()
228
+
229
+ self.groupnorm = (
230
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
+ if use_norm
232
+ else nn.Identity()
233
+ )
234
+
235
+ if use_snake:
236
+ self.activation = Snake1d(in_channels)
237
+ else:
238
+ self.activation = nn.SiLU()
239
+
240
+ self.project = Conv1d(
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ kernel_size=kernel_size,
244
+ stride=stride,
245
+ dilation=dilation,
246
+ )
247
+
248
+ def forward(
249
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
+ ) -> Tensor:
251
+ x = self.groupnorm(x)
252
+ if exists(scale_shift):
253
+ scale, shift = scale_shift
254
+ x = x * (scale + 1) + shift
255
+ x = self.activation(x)
256
+ return self.project(x, causal=causal)
257
+
258
+
259
+ class MappingToScaleShift(nn.Module):
260
+ def __init__(
261
+ self,
262
+ features: int,
263
+ channels: int,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.to_scale_shift = nn.Sequential(
268
+ nn.SiLU(),
269
+ nn.Linear(in_features=features, out_features=channels * 2),
270
+ )
271
+
272
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
+ scale_shift = self.to_scale_shift(mapping)
274
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
+ scale, shift = scale_shift.chunk(2, dim=1)
276
+ return scale, shift
277
+
278
+
279
+ class ResnetBlock1d(nn.Module):
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ out_channels: int,
284
+ *,
285
+ kernel_size: int = 3,
286
+ stride: int = 1,
287
+ dilation: int = 1,
288
+ use_norm: bool = True,
289
+ use_snake: bool = False,
290
+ num_groups: int = 8,
291
+ context_mapping_features: Optional[int] = None,
292
+ ) -> None:
293
+ super().__init__()
294
+
295
+ self.use_mapping = exists(context_mapping_features)
296
+
297
+ self.block1 = ConvBlock1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=kernel_size,
301
+ stride=stride,
302
+ dilation=dilation,
303
+ use_norm=use_norm,
304
+ num_groups=num_groups,
305
+ use_snake=use_snake
306
+ )
307
+
308
+ if self.use_mapping:
309
+ assert exists(context_mapping_features)
310
+ self.to_scale_shift = MappingToScaleShift(
311
+ features=context_mapping_features, channels=out_channels
312
+ )
313
+
314
+ self.block2 = ConvBlock1d(
315
+ in_channels=out_channels,
316
+ out_channels=out_channels,
317
+ use_norm=use_norm,
318
+ num_groups=num_groups,
319
+ use_snake=use_snake
320
+ )
321
+
322
+ self.to_out = (
323
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
+ if in_channels != out_channels
325
+ else nn.Identity()
326
+ )
327
+
328
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
+ assert_message = "context mapping required if context_mapping_features > 0"
330
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
331
+
332
+ h = self.block1(x, causal=causal)
333
+
334
+ scale_shift = None
335
+ if self.use_mapping:
336
+ scale_shift = self.to_scale_shift(mapping)
337
+
338
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
+
340
+ return h + self.to_out(x)
341
+
342
+
343
+ class Patcher(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ patch_size: int,
349
+ context_mapping_features: Optional[int] = None,
350
+ use_snake: bool = False,
351
+ ):
352
+ super().__init__()
353
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
+ assert out_channels % patch_size == 0, assert_message
355
+ self.patch_size = patch_size
356
+
357
+ self.block = ResnetBlock1d(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels // patch_size,
360
+ num_groups=1,
361
+ context_mapping_features=context_mapping_features,
362
+ use_snake=use_snake
363
+ )
364
+
365
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
+ x = self.block(x, mapping, causal=causal)
367
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
+ return x
369
+
370
+
371
+ class Unpatcher(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ patch_size: int,
377
+ context_mapping_features: Optional[int] = None,
378
+ use_snake: bool = False
379
+ ):
380
+ super().__init__()
381
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
+ assert in_channels % patch_size == 0, assert_message
383
+ self.patch_size = patch_size
384
+
385
+ self.block = ResnetBlock1d(
386
+ in_channels=in_channels // patch_size,
387
+ out_channels=out_channels,
388
+ num_groups=1,
389
+ context_mapping_features=context_mapping_features,
390
+ use_snake=use_snake
391
+ )
392
+
393
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
+ x = self.block(x, mapping, causal=causal)
396
+ return x
397
+
398
+
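A shape sketch for the patching pair above; the channel counts and sequence length are illustrative, not taken from any config in this commit:

# Patcher folds patch_size time steps into channels; Unpatcher reverses it.
patcher = Patcher(in_channels=2, out_channels=128, patch_size=2)
unpatcher = Unpatcher(in_channels=128, out_channels=2, patch_size=2)
x = torch.randn(1, 2, 2048)
z = patcher(x)       # (1, 128, 1024): channels * patch_size, length / patch_size
y = unpatcher(z)     # (1, 2, 2048): the inverse reshape followed by a ResnetBlock1d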
399
+ """
400
+ Attention Components
401
+ """
402
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
403
+ mid_features = features * multiplier
404
+ return nn.Sequential(
405
+ nn.Linear(in_features=features, out_features=mid_features),
406
+ nn.GELU(),
407
+ nn.Linear(in_features=mid_features, out_features=features),
408
+ )
409
+
410
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
+ b, ndim = sim.shape[0], mask.ndim
412
+ if ndim == 3:
413
+ mask = rearrange(mask, "b n m -> b 1 n m")
414
+ if ndim == 2:
415
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
416
+ max_neg_value = -torch.finfo(sim.dtype).max
417
+ sim = sim.masked_fill(~mask, max_neg_value)
418
+ return sim
419
+
420
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
+ mask = repeat(mask, "n m -> b n m", b=b)
424
+ return mask
425
+
426
+ class AttentionBase(nn.Module):
427
+ def __init__(
428
+ self,
429
+ features: int,
430
+ *,
431
+ head_features: int,
432
+ num_heads: int,
433
+ out_features: Optional[int] = None,
434
+ ):
435
+ super().__init__()
436
+ self.scale = head_features**-0.5
437
+ self.num_heads = num_heads
438
+ mid_features = head_features * num_heads
439
+ out_features = default(out_features, features)
440
+
441
+ self.to_out = nn.Linear(
442
+ in_features=mid_features, out_features=out_features
443
+ )
444
+
445
+ self.use_flash = False
446
+
447
+ if not self.use_flash:
448
+ return
449
+
450
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
+
452
+ if device_properties.major == 8 and device_properties.minor == 0:
453
+ # Use flash attention for A100 GPUs
454
+ self.sdp_kernel_config = (True, False, False)
455
+ else:
456
+ # Don't use flash attention for other GPUs
457
+ self.sdp_kernel_config = (False, True, True)
458
+
459
+ def forward(
460
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
+ ) -> Tensor:
462
+ # Split heads
463
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
+
465
+ if not self.use_flash:
466
+ if is_causal and not exists(mask):
467
+ # Mask out future tokens for causal attention
468
+ mask = causal_mask(q, k)
469
+
470
+ # Compute similarity matrix and add eventual mask
471
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
+ sim = add_mask(sim, mask) if exists(mask) else sim
473
+
474
+ # Get attention matrix with softmax
475
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
476
+
477
+ # Compute values
478
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
479
+ else:
480
+ with sdp_kernel(*self.sdp_kernel_config):
481
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
+
483
+ out = rearrange(out, "b h n d -> b n (h d)")
484
+ return self.to_out(out)
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self,
489
+ features: int,
490
+ *,
491
+ head_features: int,
492
+ num_heads: int,
493
+ out_features: Optional[int] = None,
494
+ context_features: Optional[int] = None,
495
+ causal: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.context_features = context_features
499
+ self.causal = causal
500
+ mid_features = head_features * num_heads
501
+ context_features = default(context_features, features)
502
+
503
+ self.norm = nn.LayerNorm(features)
504
+ self.norm_context = nn.LayerNorm(context_features)
505
+ self.to_q = nn.Linear(
506
+ in_features=features, out_features=mid_features, bias=False
507
+ )
508
+ self.to_kv = nn.Linear(
509
+ in_features=context_features, out_features=mid_features * 2, bias=False
510
+ )
511
+ self.attention = AttentionBase(
512
+ features,
513
+ num_heads=num_heads,
514
+ head_features=head_features,
515
+ out_features=out_features,
516
+ )
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor, # [b, n, c]
521
+ context: Optional[Tensor] = None, # [b, m, d]
522
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
523
+ causal: Optional[bool] = False,
524
+ ) -> Tensor:
525
+ assert_message = "You must provide a context when using context_features"
526
+ assert not self.context_features or exists(context), assert_message
527
+ # Use context if provided
528
+ context = default(context, x)
529
+ # Normalize then compute q from input and k,v from context
530
+ x, context = self.norm(x), self.norm_context(context)
531
+
532
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
+
534
+ if exists(context_mask):
535
+ # Mask out cross-attention for padding tokens
536
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
+ k, v = k * mask, v * mask
538
+
539
+ # Compute and return attention
540
+ return self.attention(q, k, v, is_causal=self.causal or causal)
541
+
542
+
543
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
544
+ mid_features = features * multiplier
545
+ return nn.Sequential(
546
+ nn.Linear(in_features=features, out_features=mid_features),
547
+ nn.GELU(),
548
+ nn.Linear(in_features=mid_features, out_features=features),
549
+ )
550
+
551
+ """
552
+ Transformer Blocks
553
+ """
554
+
555
+
556
+ class TransformerBlock(nn.Module):
557
+ def __init__(
558
+ self,
559
+ features: int,
560
+ num_heads: int,
561
+ head_features: int,
562
+ multiplier: int,
563
+ context_features: Optional[int] = None,
564
+ ):
565
+ super().__init__()
566
+
567
+ self.use_cross_attention = exists(context_features) and context_features > 0
568
+
569
+ self.attention = Attention(
570
+ features=features,
571
+ num_heads=num_heads,
572
+ head_features=head_features
573
+ )
574
+
575
+ if self.use_cross_attention:
576
+ self.cross_attention = Attention(
577
+ features=features,
578
+ num_heads=num_heads,
579
+ head_features=head_features,
580
+ context_features=context_features
581
+ )
582
+
583
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
+
585
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
+ x = self.attention(x, causal=causal) + x
587
+ if self.use_cross_attention:
588
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
+ x = self.feed_forward(x) + x
590
+ return x
591
+
592
+
593
+ """
594
+ Transformers
595
+ """
596
+
597
+
598
+ class Transformer1d(nn.Module):
599
+ def __init__(
600
+ self,
601
+ num_layers: int,
602
+ channels: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ context_features: Optional[int] = None,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.to_in = nn.Sequential(
611
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
+ Conv1d(
613
+ in_channels=channels,
614
+ out_channels=channels,
615
+ kernel_size=1,
616
+ ),
617
+ Rearrange("b c t -> b t c"),
618
+ )
619
+
620
+ self.blocks = nn.ModuleList(
621
+ [
622
+ TransformerBlock(
623
+ features=channels,
624
+ head_features=head_features,
625
+ num_heads=num_heads,
626
+ multiplier=multiplier,
627
+ context_features=context_features,
628
+ )
629
+ for i in range(num_layers)
630
+ ]
631
+ )
632
+
633
+ self.to_out = nn.Sequential(
634
+ Rearrange("b t c -> b c t"),
635
+ Conv1d(
636
+ in_channels=channels,
637
+ out_channels=channels,
638
+ kernel_size=1,
639
+ ),
640
+ )
641
+
642
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
+ x = self.to_in(x)
644
+ for block in self.blocks:
645
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
646
+ x = self.to_out(x)
647
+ return x
648
+
649
+
650
+ """
651
+ Time Embeddings
652
+ """
653
+
654
+
655
+ class SinusoidalEmbedding(nn.Module):
656
+ def __init__(self, dim: int):
657
+ super().__init__()
658
+ self.dim = dim
659
+
660
+ def forward(self, x: Tensor) -> Tensor:
661
+ device, half_dim = x.device, self.dim // 2
662
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
+
667
+
668
+ class LearnedPositionalEmbedding(nn.Module):
669
+ """Used for continuous time"""
670
+
671
+ def __init__(self, dim: int):
672
+ super().__init__()
673
+ assert (dim % 2) == 0
674
+ half_dim = dim // 2
675
+ self.weights = nn.Parameter(torch.randn(half_dim))
676
+
677
+ def forward(self, x: Tensor) -> Tensor:
678
+ x = rearrange(x, "b -> b 1")
679
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
+ fouriered = torch.cat((x, fouriered), dim=-1)
682
+ return fouriered
683
+
684
+
685
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
+ return nn.Sequential(
687
+ LearnedPositionalEmbedding(dim),
688
+ nn.Linear(in_features=dim + 1, out_features=out_features),
689
+ )
690
+
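A quick shape check for the continuous-time embedding, with arbitrary sizes:

to_time = TimePositionalEmbedding(dim=256, out_features=768)
t = torch.rand(4)        # one scalar timestep per batch item
emb = to_time(t)         # (4, 768): [t, sin, cos] Fourier features of t, projected by a Linear layer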
691
+
692
+ """
693
+ Encoder/Decoder Components
694
+ """
695
+
696
+
697
+ class DownsampleBlock1d(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ *,
703
+ factor: int,
704
+ num_groups: int,
705
+ num_layers: int,
706
+ kernel_multiplier: int = 2,
707
+ use_pre_downsample: bool = True,
708
+ use_skip: bool = False,
709
+ use_snake: bool = False,
710
+ extract_channels: int = 0,
711
+ context_channels: int = 0,
712
+ num_transformer_blocks: int = 0,
713
+ attention_heads: Optional[int] = None,
714
+ attention_features: Optional[int] = None,
715
+ attention_multiplier: Optional[int] = None,
716
+ context_mapping_features: Optional[int] = None,
717
+ context_embedding_features: Optional[int] = None,
718
+ ):
719
+ super().__init__()
720
+ self.use_pre_downsample = use_pre_downsample
721
+ self.use_skip = use_skip
722
+ self.use_transformer = num_transformer_blocks > 0
723
+ self.use_extract = extract_channels > 0
724
+ self.use_context = context_channels > 0
725
+
726
+ channels = out_channels if use_pre_downsample else in_channels
727
+
728
+ self.downsample = Downsample1d(
729
+ in_channels=in_channels,
730
+ out_channels=out_channels,
731
+ factor=factor,
732
+ kernel_multiplier=kernel_multiplier,
733
+ )
734
+
735
+ self.blocks = nn.ModuleList(
736
+ [
737
+ ResnetBlock1d(
738
+ in_channels=channels + context_channels if i == 0 else channels,
739
+ out_channels=channels,
740
+ num_groups=num_groups,
741
+ context_mapping_features=context_mapping_features,
742
+ use_snake=use_snake
743
+ )
744
+ for i in range(num_layers)
745
+ ]
746
+ )
747
+
748
+ if self.use_transformer:
749
+ assert (
750
+ (exists(attention_heads) or exists(attention_features))
751
+ and exists(attention_multiplier)
752
+ )
753
+
754
+ if attention_features is None and attention_heads is not None:
755
+ attention_features = channels // attention_heads
756
+
757
+ if attention_heads is None and attention_features is not None:
758
+ attention_heads = channels // attention_features
759
+
760
+ self.transformer = Transformer1d(
761
+ num_layers=num_transformer_blocks,
762
+ channels=channels,
763
+ num_heads=attention_heads,
764
+ head_features=attention_features,
765
+ multiplier=attention_multiplier,
766
+ context_features=context_embedding_features
767
+ )
768
+
769
+ if self.use_extract:
770
+ num_extract_groups = min(num_groups, extract_channels)
771
+ self.to_extracted = ResnetBlock1d(
772
+ in_channels=out_channels,
773
+ out_channels=extract_channels,
774
+ num_groups=num_extract_groups,
775
+ use_snake=use_snake
776
+ )
777
+
778
+ def forward(
779
+ self,
780
+ x: Tensor,
781
+ *,
782
+ mapping: Optional[Tensor] = None,
783
+ channels: Optional[Tensor] = None,
784
+ embedding: Optional[Tensor] = None,
785
+ embedding_mask: Optional[Tensor] = None,
786
+ causal: Optional[bool] = False
787
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
+
789
+ if self.use_pre_downsample:
790
+ x = self.downsample(x)
791
+
792
+ if self.use_context and exists(channels):
793
+ x = torch.cat([x, channels], dim=1)
794
+
795
+ skips = []
796
+ for block in self.blocks:
797
+ x = block(x, mapping=mapping, causal=causal)
798
+ skips += [x] if self.use_skip else []
799
+
800
+ if self.use_transformer:
801
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
+ skips += [x] if self.use_skip else []
803
+
804
+ if not self.use_pre_downsample:
805
+ x = self.downsample(x)
806
+
807
+ if self.use_extract:
808
+ extracted = self.to_extracted(x)
809
+ return x, extracted
810
+
811
+ return (x, skips) if self.use_skip else x
812
+
813
+
814
+ class UpsampleBlock1d(nn.Module):
815
+ def __init__(
816
+ self,
817
+ in_channels: int,
818
+ out_channels: int,
819
+ *,
820
+ factor: int,
821
+ num_layers: int,
822
+ num_groups: int,
823
+ use_nearest: bool = False,
824
+ use_pre_upsample: bool = False,
825
+ use_skip: bool = False,
826
+ use_snake: bool = False,
827
+ skip_channels: int = 0,
828
+ use_skip_scale: bool = False,
829
+ extract_channels: int = 0,
830
+ num_transformer_blocks: int = 0,
831
+ attention_heads: Optional[int] = None,
832
+ attention_features: Optional[int] = None,
833
+ attention_multiplier: Optional[int] = None,
834
+ context_mapping_features: Optional[int] = None,
835
+ context_embedding_features: Optional[int] = None,
836
+ ):
837
+ super().__init__()
838
+
839
+ self.use_extract = extract_channels > 0
840
+ self.use_pre_upsample = use_pre_upsample
841
+ self.use_transformer = num_transformer_blocks > 0
842
+ self.use_skip = use_skip
843
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
+
845
+ channels = out_channels if use_pre_upsample else in_channels
846
+
847
+ self.blocks = nn.ModuleList(
848
+ [
849
+ ResnetBlock1d(
850
+ in_channels=channels + skip_channels,
851
+ out_channels=channels,
852
+ num_groups=num_groups,
853
+ context_mapping_features=context_mapping_features,
854
+ use_snake=use_snake
855
+ )
856
+ for _ in range(num_layers)
857
+ ]
858
+ )
859
+
860
+ if self.use_transformer:
861
+ assert (
862
+ (exists(attention_heads) or exists(attention_features))
863
+ and exists(attention_multiplier)
864
+ )
865
+
866
+ if attention_features is None and attention_heads is not None:
867
+ attention_features = channels // attention_heads
868
+
869
+ if attention_heads is None and attention_features is not None:
870
+ attention_heads = channels // attention_features
871
+
872
+ self.transformer = Transformer1d(
873
+ num_layers=num_transformer_blocks,
874
+ channels=channels,
875
+ num_heads=attention_heads,
876
+ head_features=attention_features,
877
+ multiplier=attention_multiplier,
878
+ context_features=context_embedding_features,
879
+ )
880
+
881
+ self.upsample = Upsample1d(
882
+ in_channels=in_channels,
883
+ out_channels=out_channels,
884
+ factor=factor,
885
+ use_nearest=use_nearest,
886
+ )
887
+
888
+ if self.use_extract:
889
+ num_extract_groups = min(num_groups, extract_channels)
890
+ self.to_extracted = ResnetBlock1d(
891
+ in_channels=out_channels,
892
+ out_channels=extract_channels,
893
+ num_groups=num_extract_groups,
894
+ use_snake=use_snake
895
+ )
896
+
897
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
+ return torch.cat([x, skip * self.skip_scale], dim=1)
899
+
900
+ def forward(
901
+ self,
902
+ x: Tensor,
903
+ *,
904
+ skips: Optional[List[Tensor]] = None,
905
+ mapping: Optional[Tensor] = None,
906
+ embedding: Optional[Tensor] = None,
907
+ embedding_mask: Optional[Tensor] = None,
908
+ causal: Optional[bool] = False
909
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
+
911
+ if self.use_pre_upsample:
912
+ x = self.upsample(x)
913
+
914
+ for block in self.blocks:
915
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
+ x = block(x, mapping=mapping, causal=causal)
917
+
918
+ if self.use_transformer:
919
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
+
921
+ if not self.use_pre_upsample:
922
+ x = self.upsample(x)
923
+
924
+ if self.use_extract:
925
+ extracted = self.to_extracted(x)
926
+ return x, extracted
927
+
928
+ return x
929
+
930
+
931
+ class BottleneckBlock1d(nn.Module):
932
+ def __init__(
933
+ self,
934
+ channels: int,
935
+ *,
936
+ num_groups: int,
937
+ num_transformer_blocks: int = 0,
938
+ attention_heads: Optional[int] = None,
939
+ attention_features: Optional[int] = None,
940
+ attention_multiplier: Optional[int] = None,
941
+ context_mapping_features: Optional[int] = None,
942
+ context_embedding_features: Optional[int] = None,
943
+ use_snake: bool = False,
944
+ ):
945
+ super().__init__()
946
+ self.use_transformer = num_transformer_blocks > 0
947
+
948
+ self.pre_block = ResnetBlock1d(
949
+ in_channels=channels,
950
+ out_channels=channels,
951
+ num_groups=num_groups,
952
+ context_mapping_features=context_mapping_features,
953
+ use_snake=use_snake
954
+ )
955
+
956
+ if self.use_transformer:
957
+ assert (
958
+ (exists(attention_heads) or exists(attention_features))
959
+ and exists(attention_multiplier)
960
+ )
961
+
962
+ if attention_features is None and attention_heads is not None:
963
+ attention_features = channels // attention_heads
964
+
965
+ if attention_heads is None and attention_features is not None:
966
+ attention_heads = channels // attention_features
967
+
968
+ self.transformer = Transformer1d(
969
+ num_layers=num_transformer_blocks,
970
+ channels=channels,
971
+ num_heads=attention_heads,
972
+ head_features=attention_features,
973
+ multiplier=attention_multiplier,
974
+ context_features=context_embedding_features,
975
+ )
976
+
977
+ self.post_block = ResnetBlock1d(
978
+ in_channels=channels,
979
+ out_channels=channels,
980
+ num_groups=num_groups,
981
+ context_mapping_features=context_mapping_features,
982
+ use_snake=use_snake
983
+ )
984
+
985
+ def forward(
986
+ self,
987
+ x: Tensor,
988
+ *,
989
+ mapping: Optional[Tensor] = None,
990
+ embedding: Optional[Tensor] = None,
991
+ embedding_mask: Optional[Tensor] = None,
992
+ causal: Optional[bool] = False
993
+ ) -> Tensor:
994
+ x = self.pre_block(x, mapping=mapping, causal=causal)
995
+ if self.use_transformer:
996
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
+ x = self.post_block(x, mapping=mapping, causal=causal)
998
+ return x
999
+
1000
+
1001
+ """
1002
+ UNet
1003
+ """
1004
+
1005
+
1006
+ class UNet1d(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_channels: int,
1010
+ channels: int,
1011
+ multipliers: Sequence[int],
1012
+ factors: Sequence[int],
1013
+ num_blocks: Sequence[int],
1014
+ attentions: Sequence[int],
1015
+ patch_size: int = 1,
1016
+ resnet_groups: int = 8,
1017
+ use_context_time: bool = True,
1018
+ kernel_multiplier_downsample: int = 2,
1019
+ use_nearest_upsample: bool = False,
1020
+ use_skip_scale: bool = True,
1021
+ use_snake: bool = False,
1022
+ use_stft: bool = False,
1023
+ use_stft_context: bool = False,
1024
+ out_channels: Optional[int] = None,
1025
+ context_features: Optional[int] = None,
1026
+ context_features_multiplier: int = 4,
1027
+ context_channels: Optional[Sequence[int]] = None,
1028
+ context_embedding_features: Optional[int] = None,
1029
+ **kwargs,
1030
+ ):
1031
+ super().__init__()
1032
+ out_channels = default(out_channels, in_channels)
1033
+ context_channels = list(default(context_channels, []))
1034
+ num_layers = len(multipliers) - 1
1035
+ use_context_features = exists(context_features)
1036
+ use_context_channels = len(context_channels) > 0
1037
+ context_mapping_features = None
1038
+
1039
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
+
1041
+ self.num_layers = num_layers
1042
+ self.use_context_time = use_context_time
1043
+ self.use_context_features = use_context_features
1044
+ self.use_context_channels = use_context_channels
1045
+ self.use_stft = use_stft
1046
+ self.use_stft_context = use_stft_context
1047
+
1048
+ self.context_features = context_features
1049
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
+ context_channels = context_channels + [0] * context_channels_pad_length
1051
+ self.context_channels = context_channels
1052
+ self.context_embedding_features = context_embedding_features
1053
+
1054
+ if use_context_channels:
1055
+ has_context = [c > 0 for c in context_channels]
1056
+ self.has_context = has_context
1057
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
+
1059
+ assert (
1060
+ len(factors) == num_layers
1061
+ and len(attentions) >= num_layers
1062
+ and len(num_blocks) == num_layers
1063
+ )
1064
+
1065
+ if use_context_time or use_context_features:
1066
+ context_mapping_features = channels * context_features_multiplier
1067
+
1068
+ self.to_mapping = nn.Sequential(
1069
+ nn.Linear(context_mapping_features, context_mapping_features),
1070
+ nn.GELU(),
1071
+ nn.Linear(context_mapping_features, context_mapping_features),
1072
+ nn.GELU(),
1073
+ )
1074
+
1075
+ if use_context_time:
1076
+ assert exists(context_mapping_features)
1077
+ self.to_time = nn.Sequential(
1078
+ TimePositionalEmbedding(
1079
+ dim=channels, out_features=context_mapping_features
1080
+ ),
1081
+ nn.GELU(),
1082
+ )
1083
+
1084
+ if use_context_features:
1085
+ assert exists(context_features) and exists(context_mapping_features)
1086
+ self.to_features = nn.Sequential(
1087
+ nn.Linear(
1088
+ in_features=context_features, out_features=context_mapping_features
1089
+ ),
1090
+ nn.GELU(),
1091
+ )
1092
+
1093
+ if use_stft:
1094
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
+ in_channels *= stft_channels
1098
+ out_channels *= stft_channels
1099
+ context_channels[0] *= stft_channels if use_stft_context else 1
1100
+ assert exists(in_channels) and exists(out_channels)
1101
+ self.stft = STFT(**stft_kwargs)
1102
+
1103
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
+
1105
+ self.to_in = Patcher(
1106
+ in_channels=in_channels + context_channels[0],
1107
+ out_channels=channels * multipliers[0],
1108
+ patch_size=patch_size,
1109
+ context_mapping_features=context_mapping_features,
1110
+ use_snake=use_snake
1111
+ )
1112
+
1113
+ self.downsamples = nn.ModuleList(
1114
+ [
1115
+ DownsampleBlock1d(
1116
+ in_channels=channels * multipliers[i],
1117
+ out_channels=channels * multipliers[i + 1],
1118
+ context_mapping_features=context_mapping_features,
1119
+ context_channels=context_channels[i + 1],
1120
+ context_embedding_features=context_embedding_features,
1121
+ num_layers=num_blocks[i],
1122
+ factor=factors[i],
1123
+ kernel_multiplier=kernel_multiplier_downsample,
1124
+ num_groups=resnet_groups,
1125
+ use_pre_downsample=True,
1126
+ use_skip=True,
1127
+ use_snake=use_snake,
1128
+ num_transformer_blocks=attentions[i],
1129
+ **attention_kwargs,
1130
+ )
1131
+ for i in range(num_layers)
1132
+ ]
1133
+ )
1134
+
1135
+ self.bottleneck = BottleneckBlock1d(
1136
+ channels=channels * multipliers[-1],
1137
+ context_mapping_features=context_mapping_features,
1138
+ context_embedding_features=context_embedding_features,
1139
+ num_groups=resnet_groups,
1140
+ num_transformer_blocks=attentions[-1],
1141
+ use_snake=use_snake,
1142
+ **attention_kwargs,
1143
+ )
1144
+
1145
+ self.upsamples = nn.ModuleList(
1146
+ [
1147
+ UpsampleBlock1d(
1148
+ in_channels=channels * multipliers[i + 1],
1149
+ out_channels=channels * multipliers[i],
1150
+ context_mapping_features=context_mapping_features,
1151
+ context_embedding_features=context_embedding_features,
1152
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
+ factor=factors[i],
1154
+ use_nearest=use_nearest_upsample,
1155
+ num_groups=resnet_groups,
1156
+ use_skip_scale=use_skip_scale,
1157
+ use_pre_upsample=False,
1158
+ use_skip=True,
1159
+ use_snake=use_snake,
1160
+ skip_channels=channels * multipliers[i + 1],
1161
+ num_transformer_blocks=attentions[i],
1162
+ **attention_kwargs,
1163
+ )
1164
+ for i in reversed(range(num_layers))
1165
+ ]
1166
+ )
1167
+
1168
+ self.to_out = Unpatcher(
1169
+ in_channels=channels * multipliers[0],
1170
+ out_channels=out_channels,
1171
+ patch_size=patch_size,
1172
+ context_mapping_features=context_mapping_features,
1173
+ use_snake=use_snake
1174
+ )
1175
+
1176
+ def get_channels(
1177
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
+ ) -> Optional[Tensor]:
1179
+ """Gets context channels at `layer` and checks that shape is correct"""
1180
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1181
+ if not use_context_channels:
1182
+ return None
1183
+ assert exists(channels_list), "Missing context"
1184
+ # Get channels index (skipping zero channel contexts)
1185
+ channels_id = self.channels_ids[layer]
1186
+ # Get channels
1187
+ channels = channels_list[channels_id]
1188
+ message = f"Missing context for layer {layer} at index {channels_id}"
1189
+ assert exists(channels), message
1190
+ # Check channels
1191
+ num_channels = self.context_channels[layer]
1192
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
+ assert channels.shape[1] == num_channels, message
1194
+ # STFT channels if requested
1195
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
+ return channels
1197
+
1198
+ def get_mapping(
1199
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
+ ) -> Optional[Tensor]:
1201
+ """Combines context time features and features into mapping"""
1202
+ items, mapping = [], None
1203
+ # Compute time features
1204
+ if self.use_context_time:
1205
+ assert_message = "use_context_time=True but no time features provided"
1206
+ assert exists(time), assert_message
1207
+ items += [self.to_time(time)]
1208
+ # Compute features
1209
+ if self.use_context_features:
1210
+ assert_message = "context_features exists but no features provided"
1211
+ assert exists(features), assert_message
1212
+ items += [self.to_features(features)]
1213
+ # Compute joint mapping
1214
+ if self.use_context_time or self.use_context_features:
1215
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
+ mapping = self.to_mapping(mapping)
1217
+ return mapping
1218
+
1219
+ def forward(
1220
+ self,
1221
+ x: Tensor,
1222
+ time: Optional[Tensor] = None,
1223
+ *,
1224
+ features: Optional[Tensor] = None,
1225
+ channels_list: Optional[Sequence[Tensor]] = None,
1226
+ embedding: Optional[Tensor] = None,
1227
+ embedding_mask: Optional[Tensor] = None,
1228
+ causal: Optional[bool] = False,
1229
+ ) -> Tensor:
1230
+ channels = self.get_channels(channels_list, layer=0)
1231
+ # Apply stft if required
1232
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1233
+ # Concat context channels at layer 0 if provided
1234
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1235
+ # Compute mapping from time and features
1236
+ mapping = self.get_mapping(time, features)
1237
+ x = self.to_in(x, mapping, causal=causal)
1238
+ skips_list = [x]
1239
+
1240
+ for i, downsample in enumerate(self.downsamples):
1241
+ channels = self.get_channels(channels_list, layer=i + 1)
1242
+ x, skips = downsample(
1243
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1244
+ )
1245
+ skips_list += [skips]
1246
+
1247
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1248
+
1249
+ for i, upsample in enumerate(self.upsamples):
1250
+ skips = skips_list.pop()
1251
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
+
1253
+ x += skips_list.pop()
1254
+ x = self.to_out(x, mapping, causal=causal)
1255
+ x = self.stft.decode1d(x) if self.use_stft else x
1256
+
1257
+ return x
1258
+
1259
+
1260
+ """ Conditioning Modules """
1261
+
1262
+
1263
+ class FixedEmbedding(nn.Module):
1264
+ def __init__(self, max_length: int, features: int):
1265
+ super().__init__()
1266
+ self.max_length = max_length
1267
+ self.embedding = nn.Embedding(max_length, features)
1268
+
1269
+ def forward(self, x: Tensor) -> Tensor:
1270
+ batch_size, length, device = *x.shape[0:2], x.device
1271
+ assert_message = "Input sequence length must be <= max_length"
1272
+ assert length <= self.max_length, assert_message
1273
+ position = torch.arange(length, device=device)
1274
+ fixed_embedding = self.embedding(position)
1275
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1276
+ return fixed_embedding
1277
+
1278
+
1279
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1280
+ if proba == 1:
1281
+ return torch.ones(shape, device=device, dtype=torch.bool)
1282
+ elif proba == 0:
1283
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1284
+ else:
1285
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1286
+
1287
+
1288
+ class UNetCFG1d(UNet1d):
1289
+
1290
+ """UNet1d with Classifier-Free Guidance"""
1291
+
1292
+ def __init__(
1293
+ self,
1294
+ context_embedding_max_length: int,
1295
+ context_embedding_features: int,
1296
+ use_xattn_time: bool = False,
1297
+ **kwargs,
1298
+ ):
1299
+ super().__init__(
1300
+ context_embedding_features=context_embedding_features, **kwargs
1301
+ )
1302
+
1303
+ self.use_xattn_time = use_xattn_time
1304
+
1305
+ if use_xattn_time:
1306
+ assert exists(context_embedding_features)
1307
+ self.to_time_embedding = nn.Sequential(
1308
+ TimePositionalEmbedding(
1309
+ dim=kwargs["channels"], out_features=context_embedding_features
1310
+ ),
1311
+ nn.GELU(),
1312
+ )
1313
+
1314
+ context_embedding_max_length += 1 # Add one for time embedding
1315
+
1316
+ self.fixed_embedding = FixedEmbedding(
1317
+ max_length=context_embedding_max_length, features=context_embedding_features
1318
+ )
1319
+
1320
+ def forward( # type: ignore
1321
+ self,
1322
+ x: Tensor,
1323
+ time: Tensor,
1324
+ *,
1325
+ embedding: Tensor,
1326
+ embedding_mask: Optional[Tensor] = None,
1327
+ embedding_scale: float = 1.0,
1328
+ embedding_mask_proba: float = 0.0,
1329
+ batch_cfg: bool = False,
1330
+ rescale_cfg: bool = False,
1331
+ scale_phi: float = 0.4,
1332
+ negative_embedding: Optional[Tensor] = None,
1333
+ negative_embedding_mask: Optional[Tensor] = None,
1334
+ **kwargs,
1335
+ ) -> Tensor:
1336
+ b, device = embedding.shape[0], embedding.device
1337
+
1338
+ if self.use_xattn_time:
1339
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1340
+
1341
+ if embedding_mask is not None:
1342
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1343
+
1344
+ fixed_embedding = self.fixed_embedding(embedding)
1345
+
1346
+ if embedding_mask_proba > 0.0:
1347
+ # Randomly mask embedding
1348
+ batch_mask = rand_bool(
1349
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1350
+ )
1351
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1352
+
1353
+ if embedding_scale != 1.0:
1354
+ if batch_cfg:
1355
+ batch_x = torch.cat([x, x], dim=0)
1356
+ batch_time = torch.cat([time, time], dim=0)
1357
+
1358
+ if negative_embedding is not None:
1359
+ if negative_embedding_mask is not None:
1360
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1361
+
1362
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1363
+
1364
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1365
+
1366
+ else:
1367
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1368
+
1369
+ batch_mask = None
1370
+ if embedding_mask is not None:
1371
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1372
+
1373
+ batch_features = None
1374
+ features = kwargs.pop("features", None)
1375
+ if self.use_context_features:
1376
+ batch_features = torch.cat([features, features], dim=0)
1377
+
1378
+ batch_channels = None
1379
+ channels_list = kwargs.pop("channels_list", None)
1380
+ if self.use_context_channels:
1381
+ batch_channels = []
1382
+ for channels in channels_list:
1383
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1384
+
1385
+ # Compute both normal and fixed embedding outputs
1386
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1387
+ out, out_masked = batch_out.chunk(2, dim=0)
1388
+
1389
+ else:
1390
+ # Compute both normal and fixed embedding outputs
1391
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1392
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1393
+
1394
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1395
+
1396
+ if rescale_cfg:
1397
+
1398
+ out_std = out.std(dim=1, keepdim=True)
1399
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1400
+
1401
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1402
+
1403
+ else:
1404
+
1405
+ return out_cfg
1406
+
1407
+ else:
1408
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1409
+
1410
+
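The guidance math in the forward pass above, written out as a standalone sketch; the tensors stand in for the conditional and unconditional model outputs, and the scale values are examples only:

out        = torch.randn(2, 64, 1024)    # prediction with the real embedding
out_masked = torch.randn(2, 64, 1024)    # prediction with the fixed (unconditional) embedding
scale, phi = 6.0, 0.4
out_cfg = out_masked + (out - out_masked) * scale
# Optional rescaling pulls the guided output's per-channel std back toward the conditional one:
out_std     = out.std(dim=1, keepdim=True)
out_cfg_std = out_cfg.std(dim=1, keepdim=True)
rescaled = phi * (out_cfg * (out_std / out_cfg_std)) + (1 - phi) * out_cfg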
1411
+ class UNetNCCA1d(UNet1d):
1412
+
1413
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1414
+
1415
+ def __init__(self, context_features: int, **kwargs):
1416
+ super().__init__(context_features=context_features, **kwargs)
1417
+ self.embedder = NumberEmbedder(features=context_features)
1418
+
1419
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1420
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1421
+ return x.expand(shape)
1422
+
1423
+ def forward( # type: ignore
1424
+ self,
1425
+ x: Tensor,
1426
+ time: Tensor,
1427
+ *,
1428
+ channels_list: Sequence[Tensor],
1429
+ channels_augmentation: Union[
1430
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1431
+ ] = False,
1432
+ channels_scale: Union[
1433
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1434
+ ] = 0,
1435
+ **kwargs,
1436
+ ) -> Tensor:
1437
+ b, n = x.shape[0], len(channels_list)
1438
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1439
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1440
+
1441
+ # Augmentation (for each channel list item)
1442
+ for i in range(n):
1443
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1444
+ scale = rearrange(scale, "b -> b 1 1")
1445
+ item = channels_list[i]
1446
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1447
+
1448
+ # Scale embedding (sum reduction if more than one channel list item)
1449
+ channels_scale_emb = self.embedder(channels_scale)
1450
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1451
+
1452
+ return super().forward(
1453
+ x=x,
1454
+ time=time,
1455
+ channels_list=channels_list,
1456
+ features=channels_scale_emb,
1457
+ **kwargs,
1458
+ )
1459
+
1460
+
1461
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1462
+ def __init__(self, *args, **kwargs):
1463
+ super().__init__(*args, **kwargs)
1464
+
1465
+ def forward(self, *args, **kwargs): # type: ignore
1466
+ return UNetCFG1d.forward(self, *args, **kwargs)
1467
+
1468
+
1469
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1470
+ if type == "base":
1471
+ return UNet1d(**kwargs)
1472
+ elif type == "all":
1473
+ return UNetAll1d(**kwargs)
1474
+ elif type == "cfg":
1475
+ return UNetCFG1d(**kwargs)
1476
+ elif type == "ncca":
1477
+ return UNetNCCA1d(**kwargs)
1478
+ else:
1479
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1480
+
1481
+ class NumberEmbedder(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ features: int,
1485
+ dim: int = 256,
1486
+ ):
1487
+ super().__init__()
1488
+ self.features = features
1489
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1490
+
1491
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1492
+ if not torch.is_tensor(x):
1493
+ device = next(self.embedding.parameters()).device
1494
+ x = torch.tensor(x, device=device)
1495
+ assert isinstance(x, Tensor)
1496
+ shape = x.shape
1497
+ x = rearrange(x, "... -> (...)")
1498
+ embedding = self.embedding(x)
1499
+ x = embedding.view(*shape, self.features)
1500
+ return x # type: ignore
1501
+
1502
+
1503
+ """
1504
+ Audio Transforms
1505
+ """
1506
+
1507
+
1508
+ class STFT(nn.Module):
1509
+ """Helper for torch stft and istft"""
1510
+
1511
+ def __init__(
1512
+ self,
1513
+ num_fft: int = 1023,
1514
+ hop_length: int = 256,
1515
+ window_length: Optional[int] = None,
1516
+ length: Optional[int] = None,
1517
+ use_complex: bool = False,
1518
+ ):
1519
+ super().__init__()
1520
+ self.num_fft = num_fft
1521
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1522
+ self.window_length = default(window_length, num_fft)
1523
+ self.length = length
1524
+ self.register_buffer("window", torch.hann_window(self.window_length))
1525
+ self.use_complex = use_complex
1526
+
1527
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1528
+ b = wave.shape[0]
1529
+ wave = rearrange(wave, "b c t -> (b c) t")
1530
+
1531
+ stft = torch.stft(
1532
+ wave,
1533
+ n_fft=self.num_fft,
1534
+ hop_length=self.hop_length,
1535
+ win_length=self.window_length,
1536
+ window=self.window, # type: ignore
1537
+ return_complex=True,
1538
+ normalized=True,
1539
+ )
1540
+
1541
+ if self.use_complex:
1542
+ # Returns real and imaginary
1543
+ stft_a, stft_b = stft.real, stft.imag
1544
+ else:
1545
+ # Returns magnitude and phase matrices
1546
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1547
+ stft_a, stft_b = magnitude, phase
1548
+
1549
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1550
+
1551
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1552
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1553
+ length = closest_power_2(l * self.hop_length)
1554
+
1555
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1556
+
1557
+ if self.use_complex:
1558
+ real, imag = stft_a, stft_b
1559
+ else:
1560
+ magnitude, phase = stft_a, stft_b
1561
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1562
+
1563
+ stft = torch.stack([real, imag], dim=-1)
1564
+
1565
+ wave = torch.istft(
1566
+ stft,
1567
+ n_fft=self.num_fft,
1568
+ hop_length=self.hop_length,
1569
+ win_length=self.window_length,
1570
+ window=self.window, # type: ignore
1571
+ length=default(self.length, length),
1572
+ normalized=True,
1573
+ )
1574
+
1575
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1576
+
1577
+ def encode1d(
1578
+ self, wave: Tensor, stacked: bool = True
1579
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1580
+ stft_a, stft_b = self.encode(wave)
1581
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1582
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1583
+
1584
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1585
+ f = self.num_fft // 2 + 1
1586
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1587
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1588
+ return self.decode(stft_a, stft_b)
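A shape sketch for the STFT helper that closes this file; sizes are arbitrary and torch is assumed to be in scope:

stft = STFT(num_fft=1023, hop_length=256)
wave = torch.randn(1, 2, 65536)     # (batch, channels, samples)
spec = stft.encode1d(wave)          # (1, 2048, frames): 2 channels x 2 (magnitude, phase) x 512 bins
# decode1d(spec) unpacks the bins and calls torch.istft; note that it feeds istft the stacked
# real/imaginary layout, which recent torch releases expect as a complex tensor instead.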
stable_audio_tools/models/autoencoders.py ADDED
@@ -0,0 +1,800 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn, sin, pow
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import List, Literal, Dict, Any, Callable
11
+ from einops import rearrange
12
+
13
+ from ..inference.sampling import sample
14
+ from ..inference.utils import prepare_audio
15
+ from .blocks import SnakeBeta
16
+ from .bottleneck import Bottleneck, DiscreteBottleneck
17
+ from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
18
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
19
+ from .pretransforms import Pretransform, AutoencoderPretransform
20
+
21
+ def checkpoint(function, *args, **kwargs):
22
+ kwargs.setdefault("use_reentrant", False)
23
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
24
+
25
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
26
+ if activation == "elu":
27
+ act = nn.ELU()
28
+ elif activation == "snake":
29
+ act = SnakeBeta(channels)
30
+ elif activation == "none":
31
+ act = nn.Identity()
32
+ else:
33
+ raise ValueError(f"Unknown activation {activation}")
34
+
35
+ if antialias:
36
+ act = Activation1d(act)
37
+
38
+ return act
39
+
40
+ class ResidualUnit(nn.Module):
41
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
42
+ super().__init__()
43
+
44
+ self.dilation = dilation
45
+
46
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels)
47
+
48
+ padding = (dilation * (7-1)) // 2
49
+
50
+ self.layers = nn.Sequential(
51
+ act,
52
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
53
+ kernel_size=7, dilation=dilation, padding=padding),
54
+ act,
55
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
56
+ kernel_size=1)
57
+ )
58
+
59
+ def forward(self, x):
60
+ res = x
61
+
62
+ # Disable checkpoint until tensor mismatch is fixed
63
+ #x = checkpoint(self.layers, x)
64
+ x = self.layers(x)
65
+ return x + res
66
+
67
+ class EncoderBlock(nn.Module):
68
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
69
+ super().__init__()
70
+
71
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
72
+
73
+ self.layers = nn.Sequential(
74
+ ResidualUnit(in_channels=in_channels,
75
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
76
+ ResidualUnit(in_channels=in_channels,
77
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
78
+ ResidualUnit(in_channels=in_channels,
79
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
80
+ act,
81
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
82
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
83
+ )
84
+
85
+ def forward(self, x):
86
+ return self.layers(x)
87
+
88
+ class DecoderBlock(nn.Module):
89
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
90
+ super().__init__()
91
+
92
+ if use_nearest_upsample:
93
+ upsample_layer = nn.Sequential(
94
+ nn.Upsample(scale_factor=stride, mode="nearest"),
95
+ WNConv1d(in_channels=in_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=2*stride,
98
+ stride=1,
99
+ bias=False,
100
+ padding='same')
101
+ )
102
+ else:
103
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
106
+
107
+ act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
108
+
109
+ self.layers = nn.Sequential(
110
+ act,
111
+ upsample_layer,
112
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
113
+ dilation=1, use_snake=use_snake),
114
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
115
+ dilation=3, use_snake=use_snake),
116
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
117
+ dilation=9, use_snake=use_snake),
118
+ )
119
+
120
+ def forward(self, x):
121
+ return self.layers(x)
122
+
123
+ class OobleckEncoder(nn.Module):
124
+ def __init__(self,
125
+ in_channels=2,
126
+ channels=128,
127
+ latent_dim=32,
128
+ c_mults = [1, 2, 4, 8],
129
+ strides = [2, 4, 8, 8],
130
+ use_snake=False,
131
+ antialias_activation=False
132
+ ):
133
+ super().__init__()
134
+
135
+ c_mults = [1] + c_mults
136
+
137
+ self.depth = len(c_mults)
138
+
139
+ layers = [
140
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
141
+ ]
142
+
143
+ for i in range(self.depth-1):
144
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
145
+
146
+ layers += [
147
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
148
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
149
+ ]
150
+
151
+ self.layers = nn.Sequential(*layers)
152
+
153
+ def forward(self, x):
154
+ return self.layers(x)
155
+
156
+
157
+ class OobleckDecoder(nn.Module):
158
+ def __init__(self,
159
+ out_channels=2,
160
+ channels=128,
161
+ latent_dim=32,
162
+ c_mults = [1, 2, 4, 8],
163
+ strides = [2, 4, 8, 8],
164
+ use_snake=False,
165
+ antialias_activation=False,
166
+ use_nearest_upsample=False,
167
+ final_tanh=True):
168
+ super().__init__()
169
+
170
+ c_mults = [1] + c_mults
171
+
172
+ self.depth = len(c_mults)
173
+
174
+ layers = [
175
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
176
+ ]
177
+
178
+ for i in range(self.depth-1, 0, -1):
179
+ layers += [DecoderBlock(
180
+ in_channels=c_mults[i]*channels,
181
+ out_channels=c_mults[i-1]*channels,
182
+ stride=strides[i-1],
183
+ use_snake=use_snake,
184
+ antialias_activation=antialias_activation,
185
+ use_nearest_upsample=use_nearest_upsample
186
+ )
187
+ ]
188
+
189
+ layers += [
190
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
191
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
192
+ nn.Tanh() if final_tanh else nn.Identity()
193
+ ]
194
+
195
+ self.layers = nn.Sequential(*layers)
196
+
197
+ def forward(self, x):
198
+ return self.layers(x)
199
+
200
+ class DACEncoderWrapper(nn.Module):
201
+ def __init__(self, in_channels=1, **kwargs):
202
+ super().__init__()
203
+
204
+ from dac.model.dac import Encoder as DACEncoder
205
+
206
+ latent_dim = kwargs.pop("latent_dim", None)
207
+
208
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
209
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
210
+ self.latent_dim = latent_dim
211
+
212
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
213
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
214
+
215
+ if in_channels != 1:
216
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
217
+
218
+ def forward(self, x):
219
+ x = self.encoder(x)
220
+ x = self.proj_out(x)
221
+ return x
222
+
223
+ class DACDecoderWrapper(nn.Module):
224
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
225
+ super().__init__()
226
+
227
+ from dac.model.dac import Decoder as DACDecoder
228
+
229
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
230
+
231
+ self.latent_dim = latent_dim
232
+
233
+ def forward(self, x):
234
+ return self.decoder(x)
235
+
236
+ class AudioAutoencoder(nn.Module):
237
+ def __init__(
238
+ self,
239
+ encoder,
240
+ decoder,
241
+ latent_dim,
242
+ downsampling_ratio,
243
+ sample_rate,
244
+ io_channels=2,
245
+ bottleneck: Bottleneck = None,
246
+ pretransform: Pretransform = None,
247
+ in_channels = None,
248
+ out_channels = None,
249
+ soft_clip = False
250
+ ):
251
+ super().__init__()
252
+
253
+ self.downsampling_ratio = downsampling_ratio
254
+ self.sample_rate = sample_rate
255
+
256
+ self.latent_dim = latent_dim
257
+ self.io_channels = io_channels
258
+ self.in_channels = io_channels
259
+ self.out_channels = io_channels
260
+
261
+ self.min_length = self.downsampling_ratio
262
+
263
+ if in_channels is not None:
264
+ self.in_channels = in_channels
265
+
266
+ if out_channels is not None:
267
+ self.out_channels = out_channels
268
+
269
+ self.bottleneck = bottleneck
270
+
271
+ self.encoder = encoder
272
+
273
+ self.decoder = decoder
274
+
275
+ self.pretransform = pretransform
276
+
277
+ self.soft_clip = soft_clip
278
+
279
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
280
+
281
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
282
+
283
+ info = {}
284
+
285
+ if self.pretransform is not None and not skip_pretransform:
286
+ if self.pretransform.enable_grad:
287
+ if iterate_batch:
288
+ audios = []
289
+ for i in range(audio.shape[0]):
290
+ audios.append(self.pretransform.encode(audio[i:i+1]))
291
+ audio = torch.cat(audios, dim=0)
292
+ else:
293
+ audio = self.pretransform.encode(audio)
294
+ else:
295
+ with torch.no_grad():
296
+ if iterate_batch:
297
+ audios = []
298
+ for i in range(audio.shape[0]):
299
+ audios.append(self.pretransform.encode(audio[i:i+1]))
300
+ audio = torch.cat(audios, dim=0)
301
+ else:
302
+ audio = self.pretransform.encode(audio)
303
+
304
+ if self.encoder is not None:
305
+ if iterate_batch:
306
+ latents = []
307
+ for i in range(audio.shape[0]):
308
+ latents.append(self.encoder(audio[i:i+1]))
309
+ latents = torch.cat(latents, dim=0)
310
+ else:
311
+ latents = self.encoder(audio)
312
+ else:
313
+ latents = audio
314
+
315
+ if self.bottleneck is not None:
316
+ # TODO: Add iterate batch logic, needs to merge the info dicts
317
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
318
+
319
+ info.update(bottleneck_info)
320
+
321
+ if return_info:
322
+ return latents, info
323
+
324
+ return latents
325
+
326
+ def decode(self, latents, iterate_batch=False, **kwargs):
327
+
328
+ if self.bottleneck is not None:
329
+ if iterate_batch:
330
+ decoded = []
331
+ for i in range(latents.shape[0]):
332
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
333
+ latents = torch.cat(decoded, dim=0)
334
+ else:
335
+ latents = self.bottleneck.decode(latents)
336
+
337
+ if iterate_batch:
338
+ decoded = []
339
+ for i in range(latents.shape[0]):
340
+ decoded.append(self.decoder(latents[i:i+1]))
341
+ decoded = torch.cat(decoded, dim=0)
342
+ else:
343
+ decoded = self.decoder(latents, **kwargs)
344
+
345
+ if self.pretransform is not None:
346
+ if self.pretransform.enable_grad:
347
+ if iterate_batch:
348
+ decodeds = []
349
+ for i in range(decoded.shape[0]):
350
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
351
+ decoded = torch.cat(decodeds, dim=0)
352
+ else:
353
+ decoded = self.pretransform.decode(decoded)
354
+ else:
355
+ with torch.no_grad():
356
+ if iterate_batch:
357
+ decodeds = []
358
+ for i in range(latents.shape[0]):
359
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
360
+ decoded = torch.cat(decodeds, dim=0)
361
+ else:
362
+ decoded = self.pretransform.decode(decoded)
363
+
364
+ if self.soft_clip:
365
+ decoded = torch.tanh(decoded)
366
+
367
+ return decoded
368
+
369
+ def decode_tokens(self, tokens, **kwargs):
370
+ '''
371
+ Decode discrete tokens to audio
372
+ Only works with discrete autoencoders
373
+ '''
374
+
375
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
376
+
377
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
378
+
379
+ return self.decode(latents, **kwargs)
380
+
381
+
382
+ def preprocess_audio_for_encoder(self, audio, in_sr):
383
+ '''
384
+ Preprocess a single audio tensor (Channels x Length) to be compatible with the encoder.
385
+ If the model is mono, stereo audio will be converted to mono.
386
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
387
+ Audio will be resampled to the model's sample rate.
388
+ The output will have batch size 1 and be shape (1 x Channels x Length)
389
+ '''
390
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
391
+
392
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
393
+ '''
394
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
395
+ The audio in that list can be of different lengths and channels.
396
+ in_sr can be an integer or a list. If it's an integer, it is assumed to be the input sample_rate for every audio clip.
397
+ All audio will be resampled to the model's sample rate.
398
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
399
+ If the model is mono, all audio will be converted to mono.
400
+ The output will be a tensor of shape (Batch x Channels x Length)
401
+ '''
402
+ batch_size = len(audio_list)
403
+ if isinstance(in_sr_list, int):
404
+ in_sr_list = [in_sr_list]*batch_size
405
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
406
+ new_audio = []
407
+ max_length = 0
408
+ # resample & find the max length
409
+ for i in range(batch_size):
410
+ audio = audio_list[i]
411
+ in_sr = in_sr_list[i]
412
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
413
+ # batchsize 1 was given by accident. Just squeeze it.
414
+ audio = audio.squeeze(0)
415
+ elif len(audio.shape) == 1:
416
+ # Mono signal, channel dimension is missing, unsqueeze it in
417
+ audio = audio.unsqueeze(0)
418
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
419
+ # Resample audio
420
+ if in_sr != self.sample_rate:
421
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
422
+ audio = resample_tf(audio)
423
+ new_audio.append(audio)
424
+ if audio.shape[-1] > max_length:
425
+ max_length = audio.shape[-1]
426
+ # Pad every audio to the same length, multiple of model's downsampling ratio
427
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
428
+ for i in range(batch_size):
429
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
430
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
431
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
432
+ # convert to tensor
433
+ return torch.stack(new_audio)
434
+
435
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
436
+ '''
437
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
438
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
439
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
440
+ and therefore you likely could use the same values with decode_audio.
441
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
442
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
443
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
444
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
445
+ Smaller chunk_size uses less memory, but more compute.
446
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
447
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
448
+ '''
449
+ if not chunked:
450
+ # default behavior. Encode the entire audio in parallel
451
+ return self.encode(audio, **kwargs)
452
+ else:
453
+ # CHUNKED ENCODING
454
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
455
+ samples_per_latent = self.downsampling_ratio
456
+ total_size = audio.shape[2] # in samples
457
+ batch_size = audio.shape[0]
458
+ chunk_size *= samples_per_latent # converting metric in latents to samples
459
+ overlap *= samples_per_latent # converting metric in latents to samples
460
+ hop_size = chunk_size - overlap
461
+ chunks = []
462
+ for i in range(0, total_size - chunk_size + 1, hop_size):
463
+ chunk = audio[:,:,i:i+chunk_size]
464
+ chunks.append(chunk)
465
+ if i+chunk_size != total_size:
466
+ # Final chunk
467
+ chunk = audio[:,:,-chunk_size:]
468
+ chunks.append(chunk)
469
+ chunks = torch.stack(chunks)
470
+ num_chunks = chunks.shape[0]
471
+ # Note: y_size might be a different value from the latent length used in diffusion training
472
+ # because we can encode audio of varying lengths
473
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
474
+ y_size = total_size // samples_per_latent
475
+ # Create an empty latent, we will populate it with chunks as we encode them
476
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
477
+ for i in range(num_chunks):
478
+ x_chunk = chunks[i,:]
479
+ # encode the chunk
480
+ y_chunk = self.encode(x_chunk)
481
+ # figure out where to put the audio along the time domain
482
+ if i == num_chunks-1:
483
+ # final chunk always goes at the end
484
+ t_end = y_size
485
+ t_start = t_end - y_chunk.shape[2]
486
+ else:
487
+ t_start = i * hop_size // samples_per_latent
488
+ t_end = t_start + chunk_size // samples_per_latent
489
+ # remove the edges of the overlaps
490
+ ol = overlap//samples_per_latent//2
491
+ chunk_start = 0
492
+ chunk_end = y_chunk.shape[2]
493
+ if i > 0:
494
+ # no overlap for the start of the first chunk
495
+ t_start += ol
496
+ chunk_start += ol
497
+ if i < num_chunks-1:
498
+ # no overlap for the end of the last chunk
499
+ t_end -= ol
500
+ chunk_end -= ol
501
+ # paste the chunked audio into our y_final output audio
502
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
503
+ return y_final
504
+
505
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
506
+ '''
507
+ Decode latents to audio.
508
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
509
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
510
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
511
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
512
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
513
+ Smaller chunk_size uses less memory, but more compute.
514
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
515
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
516
+ '''
517
+ if not chunked:
518
+ # default behavior. Decode the entire latent in parallel
519
+ return self.decode(latents, **kwargs)
520
+ else:
521
+ # chunked decoding
522
+ hop_size = chunk_size - overlap
523
+ total_size = latents.shape[2]
524
+ batch_size = latents.shape[0]
525
+ chunks = []
526
+ for i in range(0, total_size - chunk_size + 1, hop_size):
527
+ chunk = latents[:,:,i:i+chunk_size]
528
+ chunks.append(chunk)
529
+ if i+chunk_size != total_size:
530
+ # Final chunk
531
+ chunk = latents[:,:,-chunk_size:]
532
+ chunks.append(chunk)
533
+ chunks = torch.stack(chunks)
534
+ num_chunks = chunks.shape[0]
535
+ # samples_per_latent is just the downsampling ratio
536
+ samples_per_latent = self.downsampling_ratio
537
+ # Create an empty waveform, we will populate it with chunks as we decode them
538
+ y_size = total_size * samples_per_latent
539
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
540
+ for i in range(num_chunks):
541
+ x_chunk = chunks[i,:]
542
+ # decode the chunk
543
+ y_chunk = self.decode(x_chunk)
544
+ # figure out where to put the audio along the time domain
545
+ if i == num_chunks-1:
546
+ # final chunk always goes at the end
547
+ t_end = y_size
548
+ t_start = t_end - y_chunk.shape[2]
549
+ else:
550
+ t_start = i * hop_size * samples_per_latent
551
+ t_end = t_start + chunk_size * samples_per_latent
552
+ # remove the edges of the overlaps
553
+ ol = (overlap//2) * samples_per_latent
554
+ chunk_start = 0
555
+ chunk_end = y_chunk.shape[2]
556
+ if i > 0:
557
+ # no overlap for the start of the first chunk
558
+ t_start += ol
559
+ chunk_start += ol
560
+ if i < num_chunks-1:
561
+ # no overlap for the end of the last chunk
562
+ t_end -= ol
563
+ chunk_end -= ol
564
+ # paste the chunked audio into our y_final output audio
565
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
566
+ return y_final
567
+
568
+
569
+ class DiffusionAutoencoder(AudioAutoencoder):
570
+ def __init__(
571
+ self,
572
+ diffusion: ConditionedDiffusionModel,
573
+ diffusion_downsampling_ratio,
574
+ *args,
575
+ **kwargs
576
+ ):
577
+ super().__init__(*args, **kwargs)
578
+
579
+ self.diffusion = diffusion
580
+
581
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
582
+
583
+ if self.encoder is not None:
584
+ # Shrink the initial encoder parameters to avoid saturated latents
585
+ with torch.no_grad():
586
+ for param in self.encoder.parameters():
587
+ param *= 0.5
588
+
589
+ def decode(self, latents, steps=100):
590
+
591
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
592
+
593
+ if self.bottleneck is not None:
594
+ latents = self.bottleneck.decode(latents)
595
+
596
+ if self.decoder is not None:
597
+ latents = self.decoder(latents)
598
+
599
+ # Upsample latents to match diffusion length
600
+ if latents.shape[2] != upsampled_length:
601
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
602
+
603
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
604
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
605
+
606
+ if self.pretransform is not None:
607
+ if self.pretransform.enable_grad:
608
+ decoded = self.pretransform.decode(decoded)
609
+ else:
610
+ with torch.no_grad():
611
+ decoded = self.pretransform.decode(decoded)
612
+
613
+ return decoded
614
+
615
+ # AE factories
616
+
617
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
618
+ encoder_type = encoder_config.get("type", None)
619
+ assert encoder_type is not None, "Encoder type must be specified"
620
+
621
+ if encoder_type == "oobleck":
622
+ encoder = OobleckEncoder(
623
+ **encoder_config["config"]
624
+ )
625
+
626
+ elif encoder_type == "seanet":
627
+ from encodec.modules import SEANetEncoder
628
+ seanet_encoder_config = encoder_config["config"]
629
+
630
+ #SEANet encoder expects strides in reverse order
631
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
632
+ encoder = SEANetEncoder(
633
+ **seanet_encoder_config
634
+ )
635
+ elif encoder_type == "dac":
636
+ dac_config = encoder_config["config"]
637
+
638
+ encoder = DACEncoderWrapper(**dac_config)
639
+ elif encoder_type == "local_attn":
640
+ from .local_attention import TransformerEncoder1D
641
+
642
+ local_attn_config = encoder_config["config"]
643
+
644
+ encoder = TransformerEncoder1D(
645
+ **local_attn_config
646
+ )
647
+ else:
648
+ raise ValueError(f"Unknown encoder type {encoder_type}")
649
+
650
+ requires_grad = encoder_config.get("requires_grad", True)
651
+ if not requires_grad:
652
+ for param in encoder.parameters():
653
+ param.requires_grad = False
654
+
655
+ return encoder
656
+
657
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
658
+ decoder_type = decoder_config.get("type", None)
659
+ assert decoder_type is not None, "Decoder type must be specified"
660
+
661
+ if decoder_type == "oobleck":
662
+ decoder = OobleckDecoder(
663
+ **decoder_config["config"]
664
+ )
665
+ elif decoder_type == "seanet":
666
+ from encodec.modules import SEANetDecoder
667
+
668
+ decoder = SEANetDecoder(
669
+ **decoder_config["config"]
670
+ )
671
+ elif decoder_type == "dac":
672
+ dac_config = decoder_config["config"]
673
+
674
+ decoder = DACDecoderWrapper(**dac_config)
675
+ elif decoder_type == "local_attn":
676
+ from .local_attention import TransformerDecoder1D
677
+
678
+ local_attn_config = decoder_config["config"]
679
+
680
+ decoder = TransformerDecoder1D(
681
+ **local_attn_config
682
+ )
683
+ else:
684
+ raise ValueError(f"Unknown decoder type {decoder_type}")
685
+
686
+ requires_grad = decoder_config.get("requires_grad", True)
687
+ if not requires_grad:
688
+ for param in decoder.parameters():
689
+ param.requires_grad = False
690
+
691
+ return decoder
692
+
693
+ def create_autoencoder_from_config(config: Dict[str, Any]):
694
+
695
+ ae_config = config["model"]
696
+
697
+ encoder = create_encoder_from_config(ae_config["encoder"])
698
+ decoder = create_decoder_from_config(ae_config["decoder"])
699
+
700
+ bottleneck = ae_config.get("bottleneck", None)
701
+
702
+ latent_dim = ae_config.get("latent_dim", None)
703
+ assert latent_dim is not None, "latent_dim must be specified in model config"
704
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
705
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
706
+ io_channels = ae_config.get("io_channels", None)
707
+ assert io_channels is not None, "io_channels must be specified in model config"
708
+ sample_rate = config.get("sample_rate", None)
709
+ assert sample_rate is not None, "sample_rate must be specified in model config"
710
+
711
+ in_channels = ae_config.get("in_channels", None)
712
+ out_channels = ae_config.get("out_channels", None)
713
+
714
+ pretransform = ae_config.get("pretransform", None)
715
+
716
+ if pretransform is not None:
717
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
718
+
719
+ if bottleneck is not None:
720
+ bottleneck = create_bottleneck_from_config(bottleneck)
721
+
722
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
723
+
724
+ return AudioAutoencoder(
725
+ encoder,
726
+ decoder,
727
+ io_channels=io_channels,
728
+ latent_dim=latent_dim,
729
+ downsampling_ratio=downsampling_ratio,
730
+ sample_rate=sample_rate,
731
+ bottleneck=bottleneck,
732
+ pretransform=pretransform,
733
+ in_channels=in_channels,
734
+ out_channels=out_channels,
735
+ soft_clip=soft_clip
736
+ )
737
+
738
+ def create_diffAE_from_config(config: Dict[str, Any]):
739
+
740
+ diffae_config = config["model"]
741
+
742
+ if "encoder" in diffae_config:
743
+ encoder = create_encoder_from_config(diffae_config["encoder"])
744
+ else:
745
+ encoder = None
746
+
747
+ if "decoder" in diffae_config:
748
+ decoder = create_decoder_from_config(diffae_config["decoder"])
749
+ else:
750
+ decoder = None
751
+
752
+ diffusion_model_type = diffae_config["diffusion"]["type"]
753
+
754
+ if diffusion_model_type == "DAU1d":
755
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
756
+ elif diffusion_model_type == "adp_1d":
757
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
758
+ elif diffusion_model_type == "dit":
759
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
760
+
761
+ latent_dim = diffae_config.get("latent_dim", None)
762
+ assert latent_dim is not None, "latent_dim must be specified in model config"
763
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
764
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
765
+ io_channels = diffae_config.get("io_channels", None)
766
+ assert io_channels is not None, "io_channels must be specified in model config"
767
+ sample_rate = config.get("sample_rate", None)
768
+ assert sample_rate is not None, "sample_rate must be specified in model config"
769
+
770
+ bottleneck = diffae_config.get("bottleneck", None)
771
+
772
+ pretransform = diffae_config.get("pretransform", None)
773
+
774
+ if pretransform is not None:
775
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
776
+
777
+ if bottleneck is not None:
778
+ bottleneck = create_bottleneck_from_config(bottleneck)
779
+
780
+ diffusion_downsampling_ratio = None
781
+
782
+ if diffusion_model_type == "DAU1d":
783
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
784
+ elif diffusion_model_type == "adp_1d":
785
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
786
+ elif diffusion_model_type == "dit":
787
+ diffusion_downsampling_ratio = 1
788
+
789
+ return DiffusionAutoencoder(
790
+ encoder=encoder,
791
+ decoder=decoder,
792
+ diffusion=diffusion,
793
+ io_channels=io_channels,
794
+ sample_rate=sample_rate,
795
+ latent_dim=latent_dim,
796
+ downsampling_ratio=downsampling_ratio,
797
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
798
+ bottleneck=bottleneck,
799
+ pretransform=pretransform
800
+ )
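The factory functions above build an AudioAutoencoder from a plain config dict with "model" and "sample_rate" keys. A minimal usage sketch of the chunked encode/decode path; the config filename and audio sizes are illustrative assumptions, and a real config would come from one of the JSON files under stable_audio_tools/configs/model_configs/autoencoders/:

import json
import torch
from stable_audio_tools.models.autoencoders import create_autoencoder_from_config

# Assumed filename; any autoencoder model config with "model" and "sample_rate" keys works
with open("stable_audio_2_0_vae.json") as f:
    config = json.load(f)

autoencoder = create_autoencoder_from_config(config).eval()

# One stereo clip at an arbitrary input sample rate; preprocessing resamples,
# pads to a multiple of the downsampling ratio and adds the batch dimension
audio = torch.randn(2, 48000 * 10)
batch = autoencoder.preprocess_audio_for_encoder(audio, in_sr=48000)

with torch.no_grad():
    # chunk_size and overlap are measured in latents, not audio samples
    latents = autoencoder.encode_audio(batch, chunked=True, chunk_size=128, overlap=32)
    recon = autoencoder.decode_audio(latents, chunked=True, chunk_size=128, overlap=32)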
stable_audio_tools/models/blocks.py ADDED
@@ -0,0 +1,339 @@
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ self.use_flash = False
45
+
46
+ if not self.use_flash:
47
+ return
48
+
49
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
+
51
+ if device_properties.major == 8 and device_properties.minor == 0:
52
+ # Use flash attention for A100 GPUs
53
+ self.sdp_kernel_config = (False, True, True)
54
+ else:
55
+ # Don't use flash attention for other GPUs
56
+ self.sdp_kernel_config = (False, True, True)
57
+
58
+ def forward(self, input):
59
+ n, c, s = input.shape
60
+ qkv = self.qkv_proj(self.norm(input))
61
+ qkv = qkv.view(
62
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
+ q, k, v = qkv.chunk(3, dim=1)
64
+ scale = k.shape[3]**-0.25
65
+
66
+ if self.use_flash:
67
+ with sdp_kernel(*self.sdp_kernel_config):
68
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
+ else:
70
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
+
73
+
74
+ return input + self.dropout(self.out_proj(y))
75
+
76
+ class SkipBlock(nn.Module):
77
+ def __init__(self, *main):
78
+ super().__init__()
79
+ self.main = nn.Sequential(*main)
80
+
81
+ def forward(self, input):
82
+ return torch.cat([self.main(input), input], dim=1)
83
+
84
+ class FourierFeatures(nn.Module):
85
+ def __init__(self, in_features, out_features, std=1.):
86
+ super().__init__()
87
+ assert out_features % 2 == 0
88
+ self.weight = nn.Parameter(torch.randn(
89
+ [out_features // 2, in_features]) * std)
90
+
91
+ def forward(self, input):
92
+ f = 2 * math.pi * input @ self.weight.T
93
+ return torch.cat([f.cos(), f.sin()], dim=-1)
94
+
95
+ def expand_to_planes(input, shape):
96
+ return input[..., None].repeat([1, 1, shape[2]])
97
+
98
+ _kernels = {
99
+ 'linear':
100
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
+ 'cubic':
102
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
+ 'lanczos3':
105
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
+ }
110
+
111
+ class Downsample1d(nn.Module):
112
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
+ super().__init__()
114
+ self.pad_mode = pad_mode
115
+ kernel_1d = torch.tensor(_kernels[kernel])
116
+ self.pad = kernel_1d.shape[0] // 2 - 1
117
+ self.register_buffer('kernel', kernel_1d)
118
+ self.channels_last = channels_last
119
+
120
+ def forward(self, x):
121
+ if self.channels_last:
122
+ x = x.permute(0, 2, 1)
123
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
+ indices = torch.arange(x.shape[1], device=x.device)
126
+ weight[indices, indices] = self.kernel.to(weight)
127
+ x = F.conv1d(x, weight, stride=2)
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ return x
131
+
132
+
133
+ class Upsample1d(nn.Module):
134
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
+ super().__init__()
136
+ self.pad_mode = pad_mode
137
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
+ self.pad = kernel_1d.shape[0] // 2 - 1
139
+ self.register_buffer('kernel', kernel_1d)
140
+ self.channels_last = channels_last
141
+
142
+ def forward(self, x):
143
+ if self.channels_last:
144
+ x = x.permute(0, 2, 1)
145
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
+ indices = torch.arange(x.shape[1], device=x.device)
148
+ weight[indices, indices] = self.kernel.to(weight)
149
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ return x
153
+
154
+ def Downsample1d_2(
155
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
+ ) -> nn.Module:
157
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
+
159
+ return nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=factor * kernel_multiplier + 1,
163
+ stride=factor,
164
+ padding=factor * (kernel_multiplier // 2),
165
+ )
166
+
167
+
168
+ def Upsample1d_2(
169
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
+ ) -> nn.Module:
171
+
172
+ if factor == 1:
173
+ return nn.Conv1d(
174
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
+ )
176
+
177
+ if use_nearest:
178
+ return nn.Sequential(
179
+ nn.Upsample(scale_factor=factor, mode="nearest"),
180
+ nn.Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=3,
184
+ padding=1,
185
+ ),
186
+ )
187
+ else:
188
+ return nn.ConvTranspose1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=factor * 2,
192
+ stride=factor,
193
+ padding=factor // 2 + factor % 2,
194
+ output_padding=factor % 2,
195
+ )
196
+
197
+ def zero_init(layer):
198
+ nn.init.zeros_(layer.weight)
199
+ if layer.bias is not None:
200
+ nn.init.zeros_(layer.bias)
201
+ return layer
202
+
203
+ def rms_norm(x, scale, eps):
204
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
+ return x * scale.to(x.dtype)
208
+
209
+ rms_norm = torch.compile(rms_norm)
210
+
211
+ class AdaRMSNorm(nn.Module):
212
+ def __init__(self, features, cond_features, eps=1e-6):
213
+ super().__init__()
214
+ self.eps = eps
215
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
+
217
+ def extra_repr(self):
218
+ return f"eps={self.eps},"
219
+
220
+ def forward(self, x, cond):
221
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
+
223
+ def normalize(x, eps=1e-4):
224
+ dim = list(range(1, x.ndim))
225
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
+ alpha = np.sqrt(n.numel() / x.numel())
227
+ return x / torch.add(eps, n, alpha=alpha)
228
+
229
+ class ForcedWNConv1d(nn.Module):
230
+ def __init__(self, in_channels, out_channels, kernel_size=1):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
+
234
+ def forward(self, x):
235
+ if self.training:
236
+ with torch.no_grad():
237
+ self.weight.copy_(normalize(self.weight))
238
+
239
+ fan_in = self.weight[0].numel()
240
+
241
+ w = normalize(self.weight) / math.sqrt(fan_in)
242
+
243
+ return F.conv1d(x, w, padding='same')
244
+
245
+ # Kernels
246
+
247
+ use_compile = True
248
+
249
+ def compile(function, *args, **kwargs):
250
+ if not use_compile:
251
+ return function
252
+ try:
253
+ return torch.compile(function, *args, **kwargs)
254
+ except RuntimeError:
255
+ return function
256
+
257
+
258
+ @compile
259
+ def linear_geglu(x, weight, bias=None):
260
+ x = x @ weight.mT
261
+ if bias is not None:
262
+ x = x + bias
263
+ x, gate = x.chunk(2, dim=-1)
264
+ return x * F.gelu(gate)
265
+
266
+
267
+ @compile
268
+ def rms_norm(x, scale, eps):
269
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
+ return x * scale.to(x.dtype)
273
+
274
+ # Layers
275
+
276
+ class LinearGEGLU(nn.Linear):
277
+ def __init__(self, in_features, out_features, bias=True):
278
+ super().__init__(in_features, out_features * 2, bias=bias)
279
+ self.out_features = out_features
280
+
281
+ def forward(self, x):
282
+ return linear_geglu(x, self.weight, self.bias)
283
+
284
+
285
+ class RMSNorm(nn.Module):
286
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
287
+ super().__init__()
288
+ self.eps = eps
289
+
290
+ if fix_scale:
291
+ self.register_buffer("scale", torch.ones(shape))
292
+ else:
293
+ self.scale = nn.Parameter(torch.ones(shape))
294
+
295
+ def extra_repr(self):
296
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
+
298
+ def forward(self, x):
299
+ return rms_norm(x, self.scale, self.eps)
300
+
301
+ def snake_beta(x, alpha, beta):
302
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
+
304
+ try:
305
+ snake_beta = torch.compile(snake_beta)
306
+ except RuntimeError:
307
+ pass
308
+
309
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
+ # License available in LICENSES/LICENSE_NVIDIA.txt
311
+ class SnakeBeta(nn.Module):
312
+
313
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
+ super(SnakeBeta, self).__init__()
315
+ self.in_features = in_features
316
+
317
+ # initialize alpha
318
+ self.alpha_logscale = alpha_logscale
319
+ if self.alpha_logscale: # log scale alphas initialized to zeros
320
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
+ else: # linear scale alphas initialized to ones
323
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
+
326
+ self.alpha.requires_grad = alpha_trainable
327
+ self.beta.requires_grad = alpha_trainable
328
+
329
+ self.no_div_by_zero = 0.000000001
330
+
331
+ def forward(self, x):
332
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
+ if self.alpha_logscale:
335
+ alpha = torch.exp(alpha)
336
+ beta = torch.exp(beta)
337
+ x = snake_beta(x, alpha, beta)
338
+
339
+ return x
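SnakeBeta above applies the snake-beta activation x + (1/beta) * sin(alpha * x)^2 with one learned alpha/beta pair per channel (stored in log scale by default). A standalone reference sketch of the same formula, with the channel count and parameter values chosen purely for illustration:

import torch

def snake_beta_ref(x, alpha, beta, eps=1e-9):
    # x + (1 / beta) * sin(alpha * x)^2, broadcast over (batch, channels, time)
    return x + (1.0 / (beta + eps)) * torch.sin(x * alpha) ** 2

x = torch.randn(4, 64, 1024)        # (batch, channels, time)
alpha = torch.ones(1, 64, 1)        # per-channel frequency
beta = torch.ones(1, 64, 1)         # per-channel magnitude
y = snake_beta_ref(x, alpha, beta)
assert y.shape == x.shape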
stable_audio_tools/models/bottleneck.py ADDED
@@ -0,0 +1,326 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from einops import rearrange
6
+ from vector_quantize_pytorch import ResidualVQ, FSQ
7
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
8
+
9
+ class Bottleneck(nn.Module):
10
+ def __init__(self, is_discrete: bool = False):
11
+ super().__init__()
12
+
13
+ self.is_discrete = is_discrete
14
+
15
+ def encode(self, x, return_info=False, **kwargs):
16
+ raise NotImplementedError
17
+
18
+ def decode(self, x):
19
+ raise NotImplementedError
20
+
21
+ class DiscreteBottleneck(Bottleneck):
22
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
23
+ super().__init__(is_discrete=True)
24
+
25
+ self.num_quantizers = num_quantizers
26
+ self.codebook_size = codebook_size
27
+ self.tokens_id = tokens_id
28
+
29
+ def decode_tokens(self, codes, **kwargs):
30
+ raise NotImplementedError
31
+
32
+ class TanhBottleneck(Bottleneck):
33
+ def __init__(self):
34
+ super().__init__(is_discrete=False)
35
+ self.tanh = nn.Tanh()
36
+
37
+ def encode(self, x, return_info=False):
38
+ info = {}
39
+
40
+ x = torch.tanh(x)
41
+
42
+ if return_info:
43
+ return x, info
44
+ else:
45
+ return x
46
+
47
+ def decode(self, x):
48
+ return x
49
+
50
+ def vae_sample(mean, scale):
51
+ stdev = nn.functional.softplus(scale) + 1e-4
52
+ var = stdev * stdev
53
+ logvar = torch.log(var)
54
+ latents = torch.randn_like(mean) * stdev + mean
55
+
56
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
57
+
58
+ return latents, kl
59
+
60
+ class VAEBottleneck(Bottleneck):
61
+ def __init__(self):
62
+ super().__init__(is_discrete=False)
63
+
64
+ def encode(self, x, return_info=False, **kwargs):
65
+ info = {}
66
+
67
+ mean, scale = x.chunk(2, dim=1)
68
+
69
+ x, kl = vae_sample(mean, scale)
70
+
71
+ info["kl"] = kl
72
+
73
+ if return_info:
74
+ return x, info
75
+ else:
76
+ return x
77
+
78
+ def decode(self, x):
79
+ return x
80
+
81
+ def compute_mean_kernel(x, y):
82
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
83
+ return torch.exp(-kernel_input).mean()
84
+
85
+ def compute_mmd(latents):
86
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
87
+ noise = torch.randn_like(latents_reshaped)
88
+
89
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
90
+ noise_kernel = compute_mean_kernel(noise, noise)
91
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
92
+
93
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
94
+ return mmd.mean()
95
+
96
+ class WassersteinBottleneck(Bottleneck):
97
+ def __init__(self, noise_augment_dim: int = 0):
98
+ super().__init__(is_discrete=False)
99
+
100
+ self.noise_augment_dim = noise_augment_dim
101
+
102
+ def encode(self, x, return_info=False):
103
+ info = {}
104
+
105
+ if self.training and return_info:
106
+ mmd = compute_mmd(x)
107
+ info["mmd"] = mmd
108
+
109
+ if return_info:
110
+ return x, info
111
+
112
+ return x
113
+
114
+ def decode(self, x):
115
+
116
+ if self.noise_augment_dim > 0:
117
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
118
+ x.shape[-1]).type_as(x)
119
+ x = torch.cat([x, noise], dim=1)
120
+
121
+ return x
122
+
123
+ class L2Bottleneck(Bottleneck):
124
+ def __init__(self):
125
+ super().__init__(is_discrete=False)
126
+
127
+ def encode(self, x, return_info=False):
128
+ info = {}
129
+
130
+ x = F.normalize(x, dim=1)
131
+
132
+ if return_info:
133
+ return x, info
134
+ else:
135
+ return x
136
+
137
+ def decode(self, x):
138
+ return F.normalize(x, dim=1)
139
+
140
+ class RVQBottleneck(DiscreteBottleneck):
141
+ def __init__(self, **quantizer_kwargs):
142
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
143
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
144
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
145
+
146
+ def encode(self, x, return_info=False, **kwargs):
147
+ info = {}
148
+
149
+ x = rearrange(x, "b c n -> b n c")
150
+ x, indices, loss = self.quantizer(x)
151
+ x = rearrange(x, "b n c -> b c n")
152
+
153
+ info["quantizer_indices"] = indices
154
+ info["quantizer_loss"] = loss.mean()
155
+
156
+ if return_info:
157
+ return x, info
158
+ else:
159
+ return x
160
+
161
+ def decode(self, x):
162
+ return x
163
+
164
+ def decode_tokens(self, codes, **kwargs):
165
+ latents = self.quantizer.get_outputs_from_indices(codes)
166
+
167
+ return self.decode(latents, **kwargs)
168
+
169
+ class RVQVAEBottleneck(DiscreteBottleneck):
170
+ def __init__(self, **quantizer_kwargs):
171
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
172
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
173
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
174
+
175
+ def encode(self, x, return_info=False):
176
+ info = {}
177
+
178
+ x, kl = vae_sample(*x.chunk(2, dim=1))
179
+
180
+ info["kl"] = kl
181
+
182
+ x = rearrange(x, "b c n -> b n c")
183
+ x, indices, loss = self.quantizer(x)
184
+ x = rearrange(x, "b n c -> b c n")
185
+
186
+ info["quantizer_indices"] = indices
187
+ info["quantizer_loss"] = loss.mean()
188
+
189
+ if return_info:
190
+ return x, info
191
+ else:
192
+ return x
193
+
194
+ def decode(self, x):
195
+ return x
196
+
197
+ def decode_tokens(self, codes, **kwargs):
198
+ latents = self.quantizer.get_outputs_from_indices(codes)
199
+
200
+ return self.decode(latents, **kwargs)
201
+
202
+ class DACRVQBottleneck(DiscreteBottleneck):
203
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
204
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
205
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
206
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
207
+ self.quantize_on_decode = quantize_on_decode
208
+
209
+ def encode(self, x, return_info=False, **kwargs):
210
+ info = {}
211
+
212
+ info["pre_quantizer"] = x
213
+
214
+ if self.quantize_on_decode:
215
+ return (x, info) if return_info else x
216
+
217
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
218
+
219
+ output = {
220
+ "z": z,
221
+ "codes": codes,
222
+ "latents": latents,
223
+ "vq/commitment_loss": commitment_loss,
224
+ "vq/codebook_loss": codebook_loss,
225
+ }
226
+
227
+ output["vq/commitment_loss"] /= self.num_quantizers
228
+ output["vq/codebook_loss"] /= self.num_quantizers
229
+
230
+ info.update(output)
231
+
232
+ if return_info:
233
+ return output["z"], info
234
+
235
+ return output["z"]
236
+
237
+ def decode(self, x):
238
+
239
+ if self.quantize_on_decode:
240
+ x = self.quantizer(x)[0]
241
+
242
+ return x
243
+
244
+ def decode_tokens(self, codes, **kwargs):
245
+ latents, _, _ = self.quantizer.from_codes(codes)
246
+
247
+ return self.decode(latents, **kwargs)
248
+
249
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
250
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
251
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
252
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
253
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
254
+ self.quantize_on_decode = quantize_on_decode
255
+
256
+ def encode(self, x, return_info=False, n_quantizers: int = None):
257
+ info = {}
258
+
259
+ mean, scale = x.chunk(2, dim=1)
260
+
261
+ x, kl = vae_sample(mean, scale)
262
+
263
+ info["pre_quantizer"] = x
264
+ info["kl"] = kl
265
+
266
+ if self.quantize_on_decode:
267
+ return (x, info) if return_info else x
268
+
269
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
270
+
271
+ output = {
272
+ "z": z,
273
+ "codes": codes,
274
+ "latents": latents,
275
+ "vq/commitment_loss": commitment_loss,
276
+ "vq/codebook_loss": codebook_loss,
277
+ }
278
+
279
+ output["vq/commitment_loss"] /= self.num_quantizers
280
+ output["vq/codebook_loss"] /= self.num_quantizers
281
+
282
+ info.update(output)
283
+
284
+ if return_info:
285
+ return output["z"], info
286
+
287
+ return output["z"]
288
+
289
+ def decode(self, x):
290
+
291
+ if self.quantize_on_decode:
292
+ x = self.quantizer(x)[0]
293
+
294
+ return x
295
+
296
+ def decode_tokens(self, codes, **kwargs):
297
+ latents, _, _ = self.quantizer.from_codes(codes)
298
+
299
+ return self.decode(latents, **kwargs)
300
+
301
+ class FSQBottleneck(DiscreteBottleneck):
302
+ def __init__(self, dim, levels):
303
+ super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
304
+ self.quantizer = FSQ(levels=[levels] * dim)
305
+
306
+ def encode(self, x, return_info=False):
307
+ info = {}
308
+
309
+ x = rearrange(x, "b c n -> b n c")
310
+ x, indices = self.quantizer(x)
311
+ x = rearrange(x, "b n c -> b c n")
312
+
313
+ info["quantizer_indices"] = indices
314
+
315
+ if return_info:
316
+ return x, info
317
+ else:
318
+ return x
319
+
320
+ def decode(self, x):
321
+ return x
322
+
323
+ def decode_tokens(self, tokens, **kwargs):
324
+ latents = self.quantizer.indices_to_codes(tokens)
325
+
326
+ return self.decode(latents, **kwargs)
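For the continuous bottlenecks, the encoder is expected to emit twice the latent width so that vae_sample can split it into mean and scale. A minimal sketch of VAEBottleneck in isolation; the shapes are illustrative assumptions, and it presumes the repo's dependencies (vector_quantize_pytorch, descript-audio-codec) are installed since the module imports them:

import torch
from stable_audio_tools.models.bottleneck import VAEBottleneck

latent_dim = 64
bottleneck = VAEBottleneck()

# Encoder output with 2 * latent_dim channels: split into mean and softplus scale
pre_bottleneck = torch.randn(2, 2 * latent_dim, 215)   # (batch, 2 * latent_dim, latent frames)
latents, info = bottleneck.encode(pre_bottleneck, return_info=True)

print(latents.shape)   # torch.Size([2, 64, 215])
print(info["kl"])      # scalar KL term added to the training loss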
stable_audio_tools/models/conditioners.py ADDED
@@ -0,0 +1,558 @@
1
+ #Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
+
3
+ import torch
4
+ import logging, warnings
5
+ import string
6
+ import typing as tp
7
+ import gc
8
+
9
+ from .adp import NumberEmbedder
10
+ from ..inference.utils import set_audio_channels
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..training.utils import copy_state_dict
14
+ from .utils import load_ckpt_state_dict
15
+
16
+ from torch import nn
17
+
18
+ class Conditioner(nn.Module):
19
+ def __init__(
20
+ self,
21
+ dim: int,
22
+ output_dim: int,
23
+ project_out: bool = False,
24
+ ):
25
+
26
+ super().__init__()
27
+
28
+ self.dim = dim
29
+ self.output_dim = output_dim
30
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
31
+
32
+ def forward(self, x: tp.Any) -> tp.Any:
33
+ raise NotImplementedError()
34
+
35
+ class IntConditioner(Conditioner):
36
+ def __init__(self,
37
+ output_dim: int,
38
+ min_val: int=0,
39
+ max_val: int=512
40
+ ):
41
+ super().__init__(output_dim, output_dim)
42
+
43
+ self.min_val = min_val
44
+ self.max_val = max_val
45
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
46
+
47
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
48
+
49
+ #self.int_embedder.to(device)
50
+
51
+ ints = torch.tensor(ints).to(device)
52
+ ints = ints.clamp(self.min_val, self.max_val)
53
+
54
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
55
+
56
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
57
+
58
+ class NumberConditioner(Conditioner):
59
+ '''
60
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
61
+ '''
62
+ def __init__(self,
63
+ output_dim: int,
64
+ min_val: float=0,
65
+ max_val: float=1
66
+ ):
67
+ super().__init__(output_dim, output_dim)
68
+
69
+ self.min_val = min_val
70
+ self.max_val = max_val
71
+
72
+ self.embedder = NumberEmbedder(features=output_dim)
73
+
74
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
75
+
76
+ # Cast the inputs to floats
77
+ floats = [float(x) for x in floats]
78
+
79
+ floats = torch.tensor(floats).to(device)
80
+
81
+ floats = floats.clamp(self.min_val, self.max_val)
82
+
83
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
84
+
85
+ # Cast floats to same type as embedder
86
+ embedder_dtype = next(self.embedder.parameters()).dtype
87
+ normalized_floats = normalized_floats.to(embedder_dtype)
88
+
89
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
90
+
91
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
92
+
93
+ class CLAPTextConditioner(Conditioner):
94
+ def __init__(self,
95
+ output_dim: int,
96
+ clap_ckpt_path,
97
+ use_text_features = False,
98
+ feature_layer_ix: int = -1,
99
+ audio_model_type="HTSAT-base",
100
+ enable_fusion=True,
101
+ project_out: bool = False,
102
+ finetune: bool = False):
103
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
104
+
105
+ self.use_text_features = use_text_features
106
+ self.feature_layer_ix = feature_layer_ix
107
+ self.finetune = finetune
108
+
109
+ # Suppress logging from transformers
110
+ previous_level = logging.root.manager.disable
111
+ logging.disable(logging.ERROR)
112
+ with warnings.catch_warnings():
113
+ warnings.simplefilter("ignore")
114
+ try:
115
+ import laion_clap
116
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
117
+
118
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
119
+
120
+ if self.finetune:
121
+ self.model = model
122
+ else:
123
+ self.__dict__["model"] = model
124
+
125
+ state_dict = clap_load_state_dict(clap_ckpt_path)
126
+ self.model.model.load_state_dict(state_dict, strict=False)
127
+
128
+ if self.finetune:
129
+ self.model.model.text_branch.requires_grad_(True)
130
+ self.model.model.text_branch.train()
131
+ else:
132
+ self.model.model.text_branch.requires_grad_(False)
133
+ self.model.model.text_branch.eval()
134
+
135
+ finally:
136
+ logging.disable(previous_level)
137
+
138
+ del self.model.model.audio_branch
139
+
140
+ gc.collect()
141
+ torch.cuda.empty_cache()
142
+
143
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
144
+ prompt_tokens = self.model.tokenizer(prompts)
145
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
146
+ prompt_features = self.model.model.text_branch(
147
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
148
+ attention_mask=attention_mask,
149
+ output_hidden_states=True
150
+ )["hidden_states"][layer_ix]
151
+
152
+ return prompt_features, attention_mask
153
+
154
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
155
+ self.model.to(device)
156
+
157
+ if self.use_text_features:
158
+ if len(texts) == 1:
159
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
160
+ text_features = text_features[:1, ...]
161
+ text_attention_mask = text_attention_mask[:1, ...]
162
+ else:
163
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
164
+ return [self.proj_out(text_features), text_attention_mask]
165
+
166
+ # Fix for CLAP bug when only one text is passed
167
+ if len(texts) == 1:
168
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
169
+ else:
170
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
171
+
172
+ text_embedding = text_embedding.unsqueeze(1).to(device)
173
+
174
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
175
+
176
+ class CLAPAudioConditioner(Conditioner):
177
+ def __init__(self,
178
+ output_dim: int,
179
+ clap_ckpt_path,
180
+ audio_model_type="HTSAT-base",
181
+ enable_fusion=True,
182
+ project_out: bool = False,
+ finetune: bool = False):
183
+ super().__init__(512, output_dim, project_out=project_out)
+ self.finetune = finetune
184
+
185
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
186
+
187
+ # Suppress logging from transformers
188
+ previous_level = logging.root.manager.disable
189
+ logging.disable(logging.ERROR)
190
+ with warnings.catch_warnings():
191
+ warnings.simplefilter("ignore")
192
+ try:
193
+ import laion_clap
194
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
195
+
196
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
197
+
198
+ if self.finetune:
199
+ self.model = model
200
+ else:
201
+ self.__dict__["model"] = model
202
+
203
+ state_dict = clap_load_state_dict(clap_ckpt_path)
204
+ self.model.model.load_state_dict(state_dict, strict=False)
205
+
206
+ if self.finetune:
207
+ self.model.model.audio_branch.requires_grad_(True)
208
+ self.model.model.audio_branch.train()
209
+ else:
210
+ self.model.model.audio_branch.requires_grad_(False)
211
+ self.model.model.audio_branch.eval()
212
+
213
+ finally:
214
+ logging.disable(previous_level)
215
+
216
+ del self.model.model.text_branch
217
+
218
+ gc.collect()
219
+ torch.cuda.empty_cache()
220
+
221
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
222
+
223
+ self.model.to(device)
224
+
225
+ if isinstance(audios, list) or isinstance(audios, tuple):
226
+ audios = torch.cat(audios, dim=0)
227
+
228
+ # Convert to mono
229
+ mono_audios = audios.mean(dim=1)
230
+
231
+ with torch.cuda.amp.autocast(enabled=False):
232
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
233
+
234
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
235
+
236
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
237
+
238
+ class T5Conditioner(Conditioner):
239
+
240
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
241
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
242
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
243
+
244
+ T5_MODEL_DIMS = {
245
+ "t5-small": 512,
246
+ "t5-base": 768,
247
+ "t5-large": 1024,
248
+ "t5-3b": 1024,
249
+ "t5-11b": 1024,
250
+ "t5-xl": 2048,
251
+ "t5-xxl": 4096,
252
+ "google/flan-t5-small": 512,
253
+ "google/flan-t5-base": 768,
254
+ "google/flan-t5-large": 1024,
255
+ "google/flan-t5-3b": 1024,
256
+ "google/flan-t5-11b": 1024,
257
+ "google/flan-t5-xl": 2048,
258
+ "google/flan-t5-xxl": 4096,
259
+ }
260
+
261
+ def __init__(
262
+ self,
263
+ output_dim: int,
264
+ t5_model_name: str = "t5-base",
265
+ max_length: int = 128,
266
+ enable_grad: bool = False,
267
+ project_out: bool = False,
268
+ ):
269
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
270
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
271
+
272
+ from transformers import T5EncoderModel, AutoTokenizer
273
+
274
+ self.max_length = max_length
275
+ self.enable_grad = enable_grad
276
+
277
+ # Suppress logging from transformers
278
+ previous_level = logging.root.manager.disable
279
+ logging.disable(logging.ERROR)
280
+ with warnings.catch_warnings():
281
+ warnings.simplefilter("ignore")
282
+ try:
283
+ # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
284
+ # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
285
+ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
286
+ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
287
+ finally:
288
+ logging.disable(previous_level)
289
+
290
+ if self.enable_grad:
291
+ self.model = model
292
+ else:
293
+ self.__dict__["model"] = model
294
+
295
+
296
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
297
+
298
+ self.model.to(device)
299
+ self.proj_out.to(device)
300
+
301
+ encoded = self.tokenizer(
302
+ texts,
303
+ truncation=True,
304
+ max_length=self.max_length,
305
+ padding="max_length",
306
+ return_tensors="pt",
307
+ )
308
+
309
+ input_ids = encoded["input_ids"].to(device)
310
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
311
+
312
+ self.model.eval()
313
+
314
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
315
+ embeddings = self.model(
316
+ input_ids=input_ids, attention_mask=attention_mask
317
+ )["last_hidden_state"]
318
+
319
+ embeddings = self.proj_out(embeddings.float())
320
+
321
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
322
+
323
+ return embeddings, attention_mask
324
+
325
+ class PhonemeConditioner(Conditioner):
326
+ """
327
+ A conditioner that turns text into phonemes and embeds them using a lookup table
328
+ Only works for English text
329
+
330
+ Args:
331
+ output_dim: the dimension of the output embeddings
332
+ max_length: the maximum number of phonemes to embed
333
+ project_out: whether to add another linear projection to the output embeddings
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ output_dim: int,
339
+ max_length: int = 1024,
340
+ project_out: bool = False,
341
+ ):
342
+ super().__init__(output_dim, output_dim, project_out=project_out)
343
+
344
+ from g2p_en import G2p
345
+
346
+ self.max_length = max_length
347
+
348
+ self.g2p = G2p()
349
+
350
+ # Reserving 0 for padding, 1 for ignored
351
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
352
+
353
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
354
+
355
+ self.phoneme_embedder.to(device)
356
+ self.proj_out.to(device)
357
+
358
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
359
+
360
+ phoneme_ignore = [" ", *string.punctuation]
361
+
362
+ # Map ignored symbols (spaces, punctuation) to a placeholder that becomes the "ignored" token id
363
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
364
+
365
+ # Convert to ids
366
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
367
+
368
+ # Pad each sequence with zeros to match the longest in the batch
369
+ longest = max([len(ids) for ids in phoneme_ids])
370
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
371
+
372
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
373
+
374
+ # Convert to embeddings
375
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
376
+
377
+ phoneme_embeds = self.proj_out(phoneme_embeds)
378
+
379
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
380
+
381
+ class TokenizerLUTConditioner(Conditioner):
382
+ """
383
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
384
+
385
+ Args:
386
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
387
+ output_dim: the dimension of the output embeddings
388
+ max_length: the maximum length of the text to embed
389
+ project_out: whether to add another linear projection to the output embeddings
390
+ """
391
+
392
+ def __init__(
393
+ self,
394
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
395
+ output_dim: int,
396
+ max_length: int = 1024,
397
+ project_out: bool = False,
398
+ ):
399
+ super().__init__(output_dim, output_dim, project_out=project_out)
400
+
401
+ from transformers import AutoTokenizer
402
+
403
+ # Suppress logging from transformers
404
+ previous_level = logging.root.manager.disable
405
+ logging.disable(logging.ERROR)
406
+ with warnings.catch_warnings():
407
+ warnings.simplefilter("ignore")
408
+ try:
409
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
410
+ finally:
411
+ logging.disable(previous_level)
412
+
413
+ self.max_length = max_length
414
+
415
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
416
+
417
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
418
+ self.proj_out.to(device)
419
+
420
+ encoded = self.tokenizer(
421
+ texts,
422
+ truncation=True,
423
+ max_length=self.max_length,
424
+ padding="max_length",
425
+ return_tensors="pt",
426
+ )
427
+
428
+ input_ids = encoded["input_ids"].to(device)
429
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
430
+
431
+ embeddings = self.token_embedder(input_ids)
432
+
433
+ embeddings = self.proj_out(embeddings)
434
+
435
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
436
+
437
+ return embeddings, attention_mask
438
+
439
+ class PretransformConditioner(Conditioner):
440
+ """
441
+ A conditioner that uses a pretransform's encoder for conditioning
442
+
443
+ Args:
444
+ pretransform: an instantiated pretransform to use for conditioning
445
+ output_dim: the dimension of the output embeddings
446
+ """
447
+ def __init__(self, pretransform: Pretransform, output_dim: int):
448
+ super().__init__(pretransform.encoded_channels, output_dim)
449
+
450
+ self.pretransform = pretransform
451
+
452
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
453
+
454
+ self.pretransform.to(device)
455
+ self.proj_out.to(device)
456
+
457
+ if isinstance(audio, list) or isinstance(audio, tuple):
458
+ audio = torch.cat(audio, dim=0)
459
+
460
+ # Convert audio to pretransform input channels
461
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
462
+
463
+ latents = self.pretransform.encode(audio)
464
+
465
+ latents = self.proj_out(latents)
466
+
467
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
468
+
469
+ class MultiConditioner(nn.Module):
470
+ """
471
+ A module that applies multiple conditioners to an input dictionary based on the keys
472
+
473
+ Args:
474
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
475
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
476
+ """
477
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
478
+ super().__init__()
479
+
480
+ self.conditioners = nn.ModuleDict(conditioners)
481
+ self.default_keys = default_keys
482
+
483
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
484
+ output = {}
485
+
486
+ for key, conditioner in self.conditioners.items():
487
+ condition_key = key
488
+
489
+ conditioner_inputs = []
490
+
491
+ for x in batch_metadata:
492
+
493
+ if condition_key not in x:
494
+ if condition_key in self.default_keys:
495
+ condition_key = self.default_keys[condition_key]
496
+ else:
497
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
498
+
499
+ # Unwrap the condition info if it's a single-element list or tuple, to support collation functions that wrap everything in a list
500
+ if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
501
+ conditioner_inputs.append(x[condition_key][0])
502
+ else:
503
+ conditioner_inputs.append(x[condition_key])
504
+
505
+ output[key] = conditioner(conditioner_inputs, device)
506
+
507
+ return output
508
+
509
+ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
510
+ """
511
+ Create a MultiConditioner from a conditioning config dictionary
512
+
513
+ Args:
514
+ config: the conditioning config dictionary
515
+ """
517
+ conditioners = {}
518
+ cond_dim = config["cond_dim"]
519
+
520
+ default_keys = config.get("default_keys", {})
521
+
522
+ for conditioner_info in config["configs"]:
523
+ id = conditioner_info["id"]
524
+
525
+ conditioner_type = conditioner_info["type"]
526
+
527
+ conditioner_config = {"output_dim": cond_dim}
528
+
529
+ conditioner_config.update(conditioner_info["config"])
530
+
531
+ if conditioner_type == "t5":
532
+ conditioners[id] = T5Conditioner(**conditioner_config)
533
+ elif conditioner_type == "clap_text":
534
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
535
+ elif conditioner_type == "clap_audio":
536
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
537
+ elif conditioner_type == "int":
538
+ conditioners[id] = IntConditioner(**conditioner_config)
539
+ elif conditioner_type == "number":
540
+ conditioners[id] = NumberConditioner(**conditioner_config)
541
+ elif conditioner_type == "phoneme":
542
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
543
+ elif conditioner_type == "lut":
544
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
545
+ elif conditioner_type == "pretransform":
546
+ sample_rate = conditioner_config.pop("sample_rate", None)
547
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
548
+
549
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
550
+
551
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
552
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
553
+
554
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
555
+ else:
556
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
557
+
558
+ return MultiConditioner(conditioners, default_keys=default_keys)
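
A minimal, hypothetical usage sketch for the conditioner factory defined above. The config values, prompt text, and metadata keys are illustrative assumptions rather than the repository's shipped model configs, and a CUDA device is assumed because the T5 encoder is cast to float16:

    from stable_audio_tools.models.conditioners import create_multi_conditioner_from_conditioning_config

    # "cond_dim" becomes each conditioner's output_dim; each entry in "configs" maps an id to a conditioner type.
    conditioning_config = {
        "cond_dim": 768,
        "configs": [
            {"id": "prompt", "type": "t5", "config": {"t5_model_name": "t5-base", "max_length": 128}},
            {"id": "seconds_total", "type": "number", "config": {"min_val": 0, "max_val": 512}},
        ],
        # If "seconds_total" is missing from the metadata, fall back to the "duration" key instead.
        "default_keys": {"seconds_total": "duration"},
    }

    # Move the embedding tables to the GPU; the T5 weights are moved inside forward().
    conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config).to("cuda")

    batch_metadata = [{"prompt": "warm analog synth pad", "duration": 47.0}]
    cond = conditioner(batch_metadata, device="cuda")
    # cond["prompt"]        -> (T5 token embeddings, attention mask)
    # cond["seconds_total"] -> [normalized-float embedding, all-ones mask]

In a full model config this dictionary sits under the "conditioning" key and is consumed by the diffusion factory in the next file.
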
stable_audio_tools/models/diffusion.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from functools import partial, reduce
5
+ import numpy as np
6
+ import typing as tp
7
+
8
+ from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ from .dit import DiffusionTransformer
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..inference.generation import generate_diffusion_cond
14
+
15
+ from .adp import UNetCFG1d, UNet1d
16
+
17
+ from time import time
18
+
19
+ class Profiler:
20
+
21
+ def __init__(self):
22
+ self.ticks = [[time(), None]]
23
+
24
+ def tick(self, msg):
25
+ self.ticks.append([time(), msg])
26
+
27
+ def __repr__(self):
28
+ rep = 80 * "=" + "\n"
29
+ for i in range(1, len(self.ticks)):
30
+ msg = self.ticks[i][1]
31
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
32
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
33
+ rep += 80 * "=" + "\n\n\n"
34
+ return rep
35
+
36
+ class DiffusionModel(nn.Module):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+ def forward(self, x, t, **kwargs):
41
+ raise NotImplementedError()
42
+
43
+ class DiffusionModelWrapper(nn.Module):
44
+ def __init__(
45
+ self,
46
+ model: DiffusionModel,
47
+ io_channels,
48
+ sample_size,
49
+ sample_rate,
50
+ min_input_length,
51
+ pretransform: tp.Optional[Pretransform] = None,
52
+ ):
53
+ super().__init__()
54
+ self.io_channels = io_channels
55
+ self.sample_size = sample_size
56
+ self.sample_rate = sample_rate
57
+ self.min_input_length = min_input_length
58
+
59
+ self.model = model
60
+
61
+ if pretransform is not None:
62
+ self.pretransform = pretransform
63
+ else:
64
+ self.pretransform = None
65
+
66
+ def forward(self, x, t, **kwargs):
67
+ return self.model(x, t, **kwargs)
68
+
69
+ class ConditionedDiffusionModel(nn.Module):
70
+ def __init__(self,
71
+ *args,
72
+ supports_cross_attention: bool = False,
73
+ supports_input_concat: bool = False,
74
+ supports_global_cond: bool = False,
75
+ supports_prepend_cond: bool = False,
76
+ **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+ self.supports_cross_attention = supports_cross_attention
79
+ self.supports_input_concat = supports_input_concat
80
+ self.supports_global_cond = supports_global_cond
81
+ self.supports_prepend_cond = supports_prepend_cond
82
+
83
+ def forward(self,
84
+ x: torch.Tensor,
85
+ t: torch.Tensor,
86
+ cross_attn_cond: torch.Tensor = None,
87
+ cross_attn_mask: torch.Tensor = None,
88
+ input_concat_cond: torch.Tensor = None,
89
+ global_embed: torch.Tensor = None,
90
+ prepend_cond: torch.Tensor = None,
91
+ prepend_cond_mask: torch.Tensor = None,
92
+ cfg_scale: float = 1.0,
93
+ cfg_dropout_prob: float = 0.0,
94
+ batch_cfg: bool = False,
95
+ rescale_cfg: bool = False,
96
+ **kwargs):
97
+ raise NotImplementedError()
98
+
99
+ class ConditionedDiffusionModelWrapper(nn.Module):
100
+ """
101
+ A diffusion model that takes in conditioning
102
+ """
103
+ def __init__(
104
+ self,
105
+ model: ConditionedDiffusionModel,
106
+ conditioner: MultiConditioner,
107
+ io_channels,
108
+ sample_rate,
109
+ min_input_length: int,
110
+ pretransform: tp.Optional[Pretransform] = None,
111
+ cross_attn_cond_ids: tp.List[str] = [],
112
+ global_cond_ids: tp.List[str] = [],
113
+ input_concat_ids: tp.List[str] = [],
114
+ prepend_cond_ids: tp.List[str] = [],
115
+ ):
116
+ super().__init__()
117
+
118
+ self.model = model
119
+ self.conditioner = conditioner
120
+ self.io_channels = io_channels
121
+ self.sample_rate = sample_rate
122
+ self.pretransform = pretransform
123
+ self.cross_attn_cond_ids = cross_attn_cond_ids
124
+ self.global_cond_ids = global_cond_ids
125
+ self.input_concat_ids = input_concat_ids
126
+ self.prepend_cond_ids = prepend_cond_ids
127
+ self.min_input_length = min_input_length
128
+
129
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
130
+ cross_attention_input = None
131
+ cross_attention_masks = None
132
+ global_cond = None
133
+ input_concat_cond = None
134
+ prepend_cond = None
135
+ prepend_cond_mask = None
136
+
137
+ if len(self.cross_attn_cond_ids) > 0:
138
+ # Concatenate all cross-attention inputs over the sequence dimension
139
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
140
+ cross_attention_input = []
141
+ cross_attention_masks = []
142
+
143
+ for key in self.cross_attn_cond_ids:
144
+ cross_attn_in, cross_attn_mask = cond[key]
145
+
146
+ # Add sequence dimension if it's not there
147
+ if len(cross_attn_in.shape) == 2:
148
+ cross_attn_in = cross_attn_in.unsqueeze(1)
149
+ cross_attn_mask = cross_attn_mask.unsqueeze(1)
150
+
151
+ cross_attention_input.append(cross_attn_in)
152
+ cross_attention_masks.append(cross_attn_mask)
153
+
154
+ cross_attention_input = torch.cat(cross_attention_input, dim=1)
155
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
156
+
157
+ if len(self.global_cond_ids) > 0:
158
+ # Concatenate all global conditioning inputs over the channel dimension
159
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
160
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
161
+ if len(global_cond.shape) == 3:
162
+ global_cond = global_cond.squeeze(1)
163
+
164
+ if len(self.input_concat_ids) > 0:
165
+ # Concatenate all input concat conditioning inputs over the channel dimension
166
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
167
+ input_concat_cond = torch.cat([cond[key][0] for key in self.input_concat_ids], dim=1)
168
+
169
+ if len(self.prepend_cond_ids) > 0:
170
+ # Concatenate all prepend conditioning inputs over the sequence dimension
171
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
172
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
173
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
174
+
175
+ if negative:
176
+ return {
177
+ "negative_cross_attn_cond": cross_attention_input,
178
+ "negative_cross_attn_mask": cross_attention_masks,
179
+ "negative_global_cond": global_cond,
180
+ "negative_input_concat_cond": input_concat_cond
181
+ }
182
+ else:
183
+ return {
184
+ "cross_attn_cond": cross_attention_input,
185
+ "cross_attn_mask": cross_attention_masks,
186
+ "global_cond": global_cond,
187
+ "input_concat_cond": input_concat_cond,
188
+ "prepend_cond": prepend_cond,
189
+ "prepend_cond_mask": prepend_cond_mask
190
+ }
191
+
192
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
193
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
194
+
195
+ def generate(self, *args, **kwargs):
196
+ return generate_diffusion_cond(self, *args, **kwargs)
197
+
198
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
199
+ def __init__(
200
+ self,
201
+ *args,
202
+ **kwargs
203
+ ):
204
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
205
+
206
+ self.model = UNetCFG1d(*args, **kwargs)
207
+
208
+ with torch.no_grad():
209
+ for param in self.model.parameters():
210
+ param *= 0.5
211
+
212
+ def forward(self,
213
+ x,
214
+ t,
215
+ cross_attn_cond=None,
216
+ cross_attn_mask=None,
217
+ input_concat_cond=None,
218
+ global_cond=None,
219
+ cfg_scale=1.0,
220
+ cfg_dropout_prob: float = 0.0,
221
+ batch_cfg: bool = False,
222
+ rescale_cfg: bool = False,
223
+ negative_cross_attn_cond=None,
224
+ negative_cross_attn_mask=None,
225
+ negative_global_cond=None,
226
+ negative_input_concat_cond=None,
227
+ prepend_cond=None,
228
+ prepend_cond_mask=None,
229
+ **kwargs):
230
+ p = Profiler()
231
+
232
+ p.tick("start")
233
+
234
+ channels_list = None
235
+ if input_concat_cond is not None:
236
+ channels_list = [input_concat_cond]
237
+
238
+ outputs = self.model(
239
+ x,
240
+ t,
241
+ embedding=cross_attn_cond,
242
+ embedding_mask=cross_attn_mask,
243
+ features=global_cond,
244
+ channels_list=channels_list,
245
+ embedding_scale=cfg_scale,
246
+ embedding_mask_proba=cfg_dropout_prob,
247
+ batch_cfg=batch_cfg,
248
+ rescale_cfg=rescale_cfg,
249
+ negative_embedding=negative_cross_attn_cond,
250
+ negative_embedding_mask=negative_cross_attn_mask,
251
+ **kwargs)
252
+
253
+ p.tick("UNetCFG1D forward")
254
+
255
+ #print(f"Profiler: {p}")
256
+ return outputs
257
+
258
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
259
+ def __init__(
260
+ self,
261
+ *args,
262
+ **kwargs
263
+ ):
264
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
265
+
266
+ self.model = UNet1d(*args, **kwargs)
267
+
268
+ with torch.no_grad():
269
+ for param in self.model.parameters():
270
+ param *= 0.5
271
+
272
+ def forward(self,
273
+ x,
274
+ t,
275
+ input_concat_cond=None,
276
+ global_cond=None,
277
+ cross_attn_cond=None,
278
+ cross_attn_mask=None,
279
+ prepend_cond=None,
280
+ prepend_cond_mask=None,
281
+ cfg_scale=1.0,
282
+ cfg_dropout_prob: float = 0.0,
283
+ batch_cfg: bool = False,
284
+ rescale_cfg: bool = False,
285
+ negative_cross_attn_cond=None,
286
+ negative_cross_attn_mask=None,
287
+ negative_global_cond=None,
288
+ negative_input_concat_cond=None,
289
+ **kwargs):
290
+
291
+ channels_list = None
292
+ if input_concat_cond is not None:
293
+
294
+ # Interpolate input_concat_cond to the same length as x
295
+ if input_concat_cond.shape[2] != x.shape[2]:
296
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
297
+
298
+ channels_list = [input_concat_cond]
299
+
300
+ outputs = self.model(
301
+ x,
302
+ t,
303
+ features=global_cond,
304
+ channels_list=channels_list,
305
+ **kwargs)
306
+
307
+ return outputs
308
+
309
+ class UNet1DUncondWrapper(DiffusionModel):
310
+ def __init__(
311
+ self,
312
+ in_channels,
313
+ *args,
314
+ **kwargs
315
+ ):
316
+ super().__init__()
317
+
318
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
319
+
320
+ self.io_channels = in_channels
321
+
322
+ with torch.no_grad():
323
+ for param in self.model.parameters():
324
+ param *= 0.5
325
+
326
+ def forward(self, x, t, **kwargs):
327
+ return self.model(x, t, **kwargs)
328
+
329
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
330
+ def __init__(
331
+ self,
332
+ *args,
333
+ **kwargs
334
+ ):
335
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
336
+
337
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
338
+
339
+ with torch.no_grad():
340
+ for param in self.model.parameters():
341
+ param *= 0.5
342
+
343
+ def forward(self,
344
+ x,
345
+ t,
346
+ input_concat_cond=None,
347
+ cross_attn_cond=None,
348
+ cross_attn_mask=None,
349
+ global_cond=None,
350
+ cfg_scale=1.0,
351
+ cfg_dropout_prob: float = 0.0,
352
+ batch_cfg: bool = False,
353
+ rescale_cfg: bool = False,
354
+ negative_cross_attn_cond=None,
355
+ negative_cross_attn_mask=None,
356
+ negative_global_cond=None,
357
+ negative_input_concat_cond=None,
358
+ prepend_cond=None,
359
+ **kwargs):
360
+
361
+ return self.model(x, t, cond = input_concat_cond)
362
+
363
+ class DiffusionAttnUnet1D(nn.Module):
364
+ def __init__(
365
+ self,
366
+ io_channels = 2,
367
+ depth=14,
368
+ n_attn_layers = 6,
369
+ channels = [128, 128, 256, 256] + [512] * 10,
370
+ cond_dim = 0,
371
+ cond_noise_aug = False,
372
+ kernel_size = 5,
373
+ learned_resample = False,
374
+ strides = [2] * 13,
375
+ conv_bias = True,
376
+ use_snake = False
377
+ ):
378
+ super().__init__()
379
+
380
+ self.cond_noise_aug = cond_noise_aug
381
+
382
+ self.io_channels = io_channels
383
+
384
+ if self.cond_noise_aug:
385
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
386
+
387
+ self.timestep_embed = FourierFeatures(1, 16)
388
+
389
+ attn_layer = depth - n_attn_layers
390
+
391
+ strides = [1] + strides
392
+
393
+ block = nn.Identity()
394
+
395
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
396
+
397
+ for i in range(depth, 0, -1):
398
+ c = channels[i - 1]
399
+ stride = strides[i-1]
400
+ if stride > 2 and not learned_resample:
401
+ raise ValueError("Must have stride 2 without learned resampling")
402
+
403
+ if i > 1:
404
+ c_prev = channels[i - 2]
405
+ add_attn = i >= attn_layer and n_attn_layers > 0
406
+ block = SkipBlock(
407
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
408
+ conv_block(c_prev, c, c),
409
+ SelfAttention1d(
410
+ c, c // 32) if add_attn else nn.Identity(),
411
+ conv_block(c, c, c),
412
+ SelfAttention1d(
413
+ c, c // 32) if add_attn else nn.Identity(),
414
+ conv_block(c, c, c),
415
+ SelfAttention1d(
416
+ c, c // 32) if add_attn else nn.Identity(),
417
+ block,
418
+ conv_block(c * 2 if i != depth else c, c, c),
419
+ SelfAttention1d(
420
+ c, c // 32) if add_attn else nn.Identity(),
421
+ conv_block(c, c, c),
422
+ SelfAttention1d(
423
+ c, c // 32) if add_attn else nn.Identity(),
424
+ conv_block(c, c, c_prev),
425
+ SelfAttention1d(c_prev, c_prev //
426
+ 32) if add_attn else nn.Identity(),
427
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
428
+ )
429
+ else:
430
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
431
+ block = nn.Sequential(
432
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
433
+ conv_block(c, c, c),
434
+ conv_block(c, c, c),
435
+ block,
436
+ conv_block(c * 2, c, c),
437
+ conv_block(c, c, c),
438
+ conv_block(c, c, io_channels, is_last=True),
439
+ )
440
+ self.net = block
441
+
442
+ with torch.no_grad():
443
+ for param in self.net.parameters():
444
+ param *= 0.5
445
+
446
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
447
+
448
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
449
+
450
+ inputs = [x, timestep_embed]
451
+
452
+ if cond is not None:
453
+ if cond.shape[2] != x.shape[2]:
454
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
455
+
456
+ if self.cond_noise_aug:
457
+ # Get a random number between 0 and 1, uniformly sampled
458
+ if cond_aug_scale is None:
459
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
460
+ else:
461
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
462
+
463
+ # Add noise to the conditioning signal
464
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
465
+
466
+ # Get embedding for the noise conditioning level, reusing timestep_embed
467
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
468
+
469
+ inputs.append(aug_level_embed)
470
+
471
+ inputs.append(cond)
472
+
473
+ outputs = self.net(torch.cat(inputs, dim=1))
474
+
475
+ return outputs
476
+
477
+ class DiTWrapper(ConditionedDiffusionModel):
478
+ def __init__(
479
+ self,
480
+ *args,
481
+ **kwargs
482
+ ):
483
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
484
+
485
+ self.model = DiffusionTransformer(*args, **kwargs)
486
+
487
+ with torch.no_grad():
488
+ for param in self.model.parameters():
489
+ param *= 0.5
490
+
491
+ def forward(self,
492
+ x,
493
+ t,
494
+ cross_attn_cond=None,
495
+ cross_attn_mask=None,
496
+ negative_cross_attn_cond=None,
497
+ negative_cross_attn_mask=None,
498
+ input_concat_cond=None,
499
+ negative_input_concat_cond=None,
500
+ global_cond=None,
501
+ negative_global_cond=None,
502
+ prepend_cond=None,
503
+ prepend_cond_mask=None,
504
+ cfg_scale=1.0,
505
+ cfg_dropout_prob: float = 0.0,
506
+ batch_cfg: bool = True,
507
+ rescale_cfg: bool = False,
508
+ scale_phi: float = 0.0,
509
+ **kwargs):
510
+
511
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
512
+ assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
513
+
514
+ return self.model(
515
+ x,
516
+ t,
517
+ cross_attn_cond=cross_attn_cond,
518
+ cross_attn_cond_mask=cross_attn_mask,
519
+ negative_cross_attn_cond=negative_cross_attn_cond,
520
+ negative_cross_attn_mask=negative_cross_attn_mask,
521
+ input_concat_cond=input_concat_cond,
522
+ prepend_cond=prepend_cond,
523
+ prepend_cond_mask=prepend_cond_mask,
524
+ cfg_scale=cfg_scale,
525
+ cfg_dropout_prob=cfg_dropout_prob,
526
+ scale_phi=scale_phi,
527
+ global_embed=global_cond,
528
+ **kwargs)
529
+
530
+ class DiTUncondWrapper(DiffusionModel):
531
+ def __init__(
532
+ self,
533
+ in_channels,
534
+ *args,
535
+ **kwargs
536
+ ):
537
+ super().__init__()
538
+
539
+ self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
540
+
541
+ self.io_channels = in_channels
542
+
543
+ with torch.no_grad():
544
+ for param in self.model.parameters():
545
+ param *= 0.5
546
+
547
+ def forward(self, x, t, **kwargs):
548
+ return self.model(x, t, **kwargs)
549
+
550
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
551
+ diffusion_uncond_config = config["model"]
552
+
553
+ model_type = diffusion_uncond_config.get('type', None)
554
+
555
+ diffusion_config = diffusion_uncond_config.get('config', {})
556
+
557
+ assert model_type is not None, "Must specify model type in config"
558
+
559
+ pretransform = diffusion_uncond_config.get("pretransform", None)
560
+
561
+ sample_size = config.get("sample_size", None)
562
+ assert sample_size is not None, "Must specify sample size in config"
563
+
564
+ sample_rate = config.get("sample_rate", None)
565
+ assert sample_rate is not None, "Must specify sample rate in config"
566
+
567
+ if pretransform is not None:
568
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
569
+ min_input_length = pretransform.downsampling_ratio
570
+ else:
571
+ min_input_length = 1
572
+
573
+ if model_type == 'DAU1d':
574
+
575
+ model = DiffusionAttnUnet1D(
576
+ **diffusion_config
577
+ )
578
+
579
+ elif model_type == "adp_uncond_1d":
580
+
581
+ model = UNet1DUncondWrapper(
582
+ **diffusion_config
583
+ )
584
+
585
+ elif model_type == "dit":
586
+ model = DiTUncondWrapper(
587
+ **diffusion_config
588
+ )
589
+
590
+ else:
591
+ raise NotImplementedError(f'Unknown model type: {model_type}')
592
+
593
+ return DiffusionModelWrapper(model,
594
+ io_channels=model.io_channels,
595
+ sample_size=sample_size,
596
+ sample_rate=sample_rate,
597
+ pretransform=pretransform,
598
+ min_input_length=min_input_length)
599
+
600
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
601
+
602
+ model_config = config["model"]
603
+
604
+ model_type = config["model_type"]
605
+
606
+ diffusion_config = model_config.get('diffusion', None)
607
+ assert diffusion_config is not None, "Must specify diffusion config"
608
+
609
+ diffusion_model_type = diffusion_config.get('type', None)
610
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
611
+
612
+ diffusion_model_config = diffusion_config.get('config', None)
613
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
614
+
615
+ if diffusion_model_type == 'adp_cfg_1d':
616
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
617
+ elif diffusion_model_type == 'adp_1d':
618
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
619
+ elif diffusion_model_type == 'dit':
620
+ diffusion_model = DiTWrapper(**diffusion_model_config)
621
+
622
+ io_channels = model_config.get('io_channels', None)
623
+ assert io_channels is not None, "Must specify io_channels in model config"
624
+
625
+ sample_rate = config.get('sample_rate', None)
626
+ assert sample_rate is not None, "Must specify sample_rate in config"
627
+
628
+ conditioning_config = model_config.get('conditioning', None)
629
+
630
+ conditioner = None
631
+ if conditioning_config is not None:
632
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
633
+
634
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
635
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
636
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
637
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
638
+
639
+ pretransform = model_config.get("pretransform", None)
640
+
641
+ if pretransform is not None:
642
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
643
+ min_input_length = pretransform.downsampling_ratio
644
+ else:
645
+ min_input_length = 1
646
+
647
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
648
+ min_input_length *= np.prod(diffusion_model_config["factors"])
649
+ elif diffusion_model_type == "dit":
650
+ min_input_length *= diffusion_model.model.patch_size
651
+
652
+ # Get the proper wrapper class
653
+
654
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
655
+ wrapper_fn = ConditionedDiffusionModelWrapper
656
+ elif model_type == "diffusion_prior":
657
+ prior_type = model_config.get("prior_type", None)
658
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
659
+
660
+ if prior_type == "mono_stereo":
661
+ from .diffusion_prior import MonoToStereoDiffusionPrior
662
+ wrapper_fn = MonoToStereoDiffusionPrior
663
+ elif prior_type == "source_separation":
664
+ from .diffusion_prior import SourceSeparationDiffusionPrior
665
+ wrapper_fn = SourceSeparationDiffusionPrior
666
+
667
+ return wrapper_fn(
668
+ diffusion_model,
669
+ conditioner,
670
+ min_input_length=min_input_length,
671
+ sample_rate=sample_rate,
672
+ cross_attn_cond_ids=cross_attention_ids,
673
+ global_cond_ids=global_cond_ids,
674
+ input_concat_ids=input_concat_ids,
675
+ prepend_cond_ids=prepend_cond_ids,
676
+ pretransform=pretransform,
677
+ io_channels=io_channels
678
+ )
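
To make the conditioning-routing comments in ConditionedDiffusionModelWrapper.get_conditioning_inputs concrete, here is a small, self-contained sketch. The shapes, conditioning ids, and the placeholder model/conditioner are assumptions for illustration only; they suffice because get_conditioning_inputs only reads the cond dictionary, and the import assumes this commit's package and its dependencies are installed:

    import torch
    from torch import nn
    from stable_audio_tools.models.diffusion import ConditionedDiffusionModelWrapper

    wrapper = ConditionedDiffusionModelWrapper(
        model=nn.Identity(),   # placeholder: never called by get_conditioning_inputs
        conditioner=None,      # placeholder
        io_channels=64,
        sample_rate=44100,
        min_input_length=2048,
        cross_attn_cond_ids=["prompt"],
        global_cond_ids=["seconds_total"],
    )

    cond = {
        "prompt": (torch.randn(2, 128, 768), torch.ones(2, 128)),     # (batch, seq, channels) + mask
        "seconds_total": (torch.randn(2, 1, 768), torch.ones(2, 1)),  # squeezed to (batch, channels)
    }

    inputs = wrapper.get_conditioning_inputs(cond)
    print(inputs["cross_attn_cond"].shape)  # torch.Size([2, 128, 768]), concatenated over the sequence dim
    print(inputs["global_cond"].shape)      # torch.Size([2, 768])

In practice the wrapper is built by create_diffusion_cond_from_config from a full model config (model_type, sample_rate, io_channels, diffusion, conditioning, pretransform); this sketch only isolates the routing step that feeds the diffusion backbone.
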