diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5959f0863b55b3c16d90c274338626e7e1a4518c --- /dev/null +++ b/app.py @@ -0,0 +1,21 @@ +import gradio as gr +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.interface.gradio import create_ui +import json + +import torch + +def run(): + torch.manual_seed(42) + + interface = create_ui( + model_config_path = "model_config_float_conditioning_dit_all.json", + ckpt_path="epoch=1292-step=602500.ckpt", + model_half=False + ) + interface.queue() + interface.launch(share=True) + +if __name__ == "__main__": + run() + diff --git a/defaults.ini b/defaults.ini new file mode 100644 index 0000000000000000000000000000000000000000..9f240a3771abaab168a008e30ab4aad82e2c7c96 --- /dev/null +++ b/defaults.ini @@ -0,0 +1,56 @@ + +[DEFAULTS] + +#name of the run +name = stable_audio_tools + +# the batch size +batch_size = 8 + +# number of GPUs to use for training +num_gpus = 1 + +# number of nodes to use for training +num_nodes = 1 + +# Multi-GPU strategy for PyTorch Lightning +strategy = "" + +# Precision to use for training +precision = "16-mixed" + +# number of CPU workers for the DataLoader +num_workers = 8 + +# the random seed +seed = 42 + +# Batches for gradient accumulation +accum_batches = 1 + +# Number of steps between checkpoints +checkpoint_every = 10000 + +# trainer checkpoint file to restart training from +ckpt_path = '' + +# model checkpoint file to start a new training run from +pretrained_ckpt_path = '' + +# Checkpoint path for the pretransform model if needed +pretransform_ckpt_path = '' + +# configuration model specifying model hyperparameters +model_config = '' + +# configuration for datasets +dataset_config = '' + +# directory to save the checkpoints in +save_dir = '' + +# gradient_clip_val passed into PyTorch Lightning Trainer +gradient_clip_val = 0.0 + +# remove the weight norm from the pretransform model +remove_pretransform_weight_norm = '' \ No newline at end of file diff --git a/model_config_float_conditioning_dit_all.json b/model_config_float_conditioning_dit_all.json new file mode 100644 index 0000000000000000000000000000000000000000..599dba99263b1ebf3c1e909b4c9310759fceab86 --- /dev/null +++ b/model_config_float_conditioning_dit_all.json @@ -0,0 +1,208 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 1048576, + "sample_rate": 44100, + "audio_channels": 1, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "dac", + "config": { + "in_channels": 1, + "latent_dim": 128, + "d_model": 128, + "strides": [4, 4, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "out_channels": 1, + "latent_dim": 64, + "channels": 1536, + "rates": [8, 8, 4, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 1024, + "io_channels": 1 + } + }, + "conditioning": { + "configs": [ + { + "id": "latitude", + "type": "number", + "config": { + "min_val": -54.617412, + "max_val": -10.13994 + } + }, + { + "id": "longitude", + "type": "number", + "config": { + "min_val": 96.8233, + "max_val": 167.9619 + } + }, + { + "id": "temperature", + "type": "number", + "config": { + "min_val": -10.0, + "max_val": 55.0 + } + }, + { + "id": "humidity", + "type": "number", + "config": { + "min_val": 1, + "max_val": 100.0 + } + }, + { + "id": "wind_speed", + "type": "number", + "config": { + "min_val": 0, + "max_val": 50.0 + } + }, + { + "id": "pressure", + "type": "number", + 
"config": { + "min_val": 800.0, + "max_val": 1200.0 + } + }, + { + "id": "minutes_of_day", + "type": "number", + "config": { + "min_val": 0, + "max_val": 1439 + } + }, + { + "id": "day_of_year", + "type": "number", + "config": { + "min_val": 1, + "max_val": 366 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["latitude", "longitude", "temperature", "humidity", "wind_speed", "pressure", "minutes_of_day", "day_of_year","seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 768, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + + "training": { + "use_ema": true, + "log_loss_info": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2500, + "demo_steps": 100, + "num_demos": 3, + "demo_cfg_scales": [3, 5, 7], + "demo_cond": [ + { + "latitude": -24.005512, + "longitude": 133.368348, + "temperature": 25.5, + "humidity": 60, + "wind_speed": 8, + "pressure": 1000, + "minutes_of_day": 400, + "day_of_year": 110, + "seconds_start": 0, + "seconds_total": 22 + }, + { + "latitude": -26.987815, + "longitude": 153.129068, + "temperature": 31.5, + "humidity": 70, + "wind_speed": 12, + "pressure": 1010, + "minutes_of_day": 600, + "day_of_year": 57, + "seconds_start": 0, + "seconds_total": 22 + }, + { + "latitude": -12.546364, + "longitude": 130.919605, + "temperature": 28.5, + "humidity": 60, + "wind_speed": 18, + "pressure": 1015, + "minutes_of_day": 1140, + "day_of_year": 280, + "seconds_start": 0, + "seconds_total": 22 + } + ] + } + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7fd26b970b848120ad4f6ab398e1864ab0111e60 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f29995f34fbb0a166ab598f0264f74747933462 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +--extra-index-url https://download.pytorch.org/whl/cu113 +torch +aeiou +alias-free-torch +auraloss +descript-audio-codec +einops +einops-exts +ema-pytorch +encodec +gradio +huggingface_hub +importlib-resources +k-diffusion +laion-clap +local-attention +pandas +pedalboard +prefigure +pytorch_lightning +PyWavelets +safetensors +sentencepiece +s3fs +torch +torchaudio +torchmetrics +tqdm +transformers +v-diffusion-pytorch +vector-quantize-pytorch +wandb +webdataset +x_transformers \ No newline at end of file diff --git a/run_gradio.py b/run_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3ba95c20079c4ea0539c5bbbb2970f31eed236 --- /dev/null +++ b/run_gradio.py @@ -0,0 +1,31 @@ +from stable_audio_tools import get_pretrained_model +from 
stable_audio_tools.interface.gradio import create_ui +import json + +import torch + +def main(args): + torch.manual_seed(42) + + interface = create_ui( + model_config_path=args.model_config, + ckpt_path=args.ckpt_path, + pretrained_name=args.pretrained_name, + pretransform_ckpt_path=args.pretransform_ckpt_path, + model_half=args.model_half + ) + interface.queue() + interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Run gradio interface') + parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) + parser.add_argument('--model-config', type=str, help='Path to model config', required=False) + parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) + parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional path to model pretransform checkpoint', required=False) + parser.add_argument('--username', type=str, help='Gradio username', required=False) + parser.add_argument('--password', type=str, help='Gradio password', required=False) + parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2c1e5cd455a980e3090c590777e8f55164c535 --- /dev/null +++ b/run_tests.py @@ -0,0 +1,44 @@ +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.interface.testing import runTests +print(runTests) # Check if it prints a function reference + + +import torch + +def main(args): + torch.manual_seed(42) + runTests(model_config_path=args.model_config, + ckpt_path=args.ckpt_path, + pretrained_name=args.pretrained_name, + pretransform_ckpt_path=args.pretransform_ckpt_path, + model_half=args.model_half, + output_dir=args.output_dir, + json_dir=args.json_dir + ) + + + + + +if __name__ == "__main__": + import argparse + import sys + parser = argparse.ArgumentParser(description='Run generation tests') + parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) + parser.add_argument('--model-config', type=str, help='Path to model config', required=False) + parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) + parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional path to model pretransform checkpoint', required=False) + parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) + parser.add_argument('--output-dir', type=str, help='Path to output directory', required=True) + parser.add_argument('--json-dir', type=str, help='Path to directory containing JSON files', required=True) + print("Running tests") + + print("Arguments provided:", sys.argv[1:]) + + args = parser.parse_args() + print("Parsed arguments:", args) + main(args) + + + + diff --git a/scripts/ds_zero_to_pl_ckpt.py b/scripts/ds_zero_to_pl_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..528a51609d0126e812b5319b85481e3184d5c8ca --- /dev/null +++ b/scripts/ds_zero_to_pl_ckpt.py @@ -0,0 +1,14 @@ +import argparse +from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() +
parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint") + parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt") + args = parser.parse_args() + + # lightning deepspeed has saved a directory instead of a file + save_path = args.save_path + output_path = args.output_path + convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e06717c2ef2e4fcf0e031aaa8527e254c3ed48 --- /dev/null +++ b/setup.py @@ -0,0 +1,46 @@ +from setuptools import setup, find_packages + +setup( + name='stable-audio-tools', + version='0.0.12', + url='https://github.com/Stability-AI/stable-audio-tools.git', + author='Stability AI', + description='Training and inference tools for generative audio models from Stability AI', + packages=find_packages(), + install_requires=[ + 'audiocraft==1.0.0', + 'aeiou==0.0.20', + 'alias-free-torch==0.0.6', + 'auraloss==0.4.0', + 'descript-audio-codec==1.0.0', + 'einops==0.7.0', + 'einops-exts==0.0.4', + 'ema-pytorch==0.2.3', + 'encodec==0.1.1', + 'flash-attn>=2.5.0', + 'gradio>=3.42.0', + 'huggingface_hub', + 'importlib-resources==5.12.0', + 'k-diffusion==0.1.1', + 'laion-clap==1.1.4', + 'local-attention==1.8.6', + 'pandas==2.0.2', + 'pedalboard==0.7.4', + 'prefigure==0.0.9', + 'pytorch_lightning==2.1.0', + 'PyWavelets==1.4.1', + 'safetensors', + 'sentencepiece==0.1.99', + 's3fs', + 'torch>=2.0.1', + 'torchaudio>=2.0.2', + 'torchmetrics==0.11.4', + 'tqdm', + 'transformers==4.33.3', + 'v-diffusion-pytorch==0.0.2', + 'vector-quantize-pytorch==1.9.14', + 'wandb==0.15.4', + 'webdataset==0.2.48', + 'x-transformers<1.27.0' + ], +) \ No newline at end of file diff --git a/stable_audio_tools/__init__.py b/stable_audio_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22446be50eb6617222c50b007a38d06490cbab41 --- /dev/null +++ b/stable_audio_tools/__init__.py @@ -0,0 +1,2 @@ +from .models.factory import create_model_from_config, create_model_from_config_path +from .models.pretrained import get_pretrained_model \ No newline at end of file diff --git a/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py b/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ca14ae77a714a40e760bc56fa97d81fb56983a --- /dev/null +++ b/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py @@ -0,0 +1,4 @@ +def get_custom_metadata(info, audio): + + # Use relative path as the prompt + return {"prompt": info["relpath"]} \ No newline at end of file diff --git a/stable_audio_tools/configs/dataset_configs/local_training_example.json b/stable_audio_tools/configs/dataset_configs/local_training_example.json new file mode 100644 index 0000000000000000000000000000000000000000..94444ea1bfb506b188fa0ce66d4979a2a0220adb --- /dev/null +++ b/stable_audio_tools/configs/dataset_configs/local_training_example.json @@ -0,0 +1,11 @@ +{ + "dataset_type": "audio_dir", + "datasets": [ + { + "id": "my_audio", + "path": "/path/to/audio/dataset/" + } + ], + "custom_metadata_module": "/path/to/custom_metadata/custom_md_example.py", + "random_crop": true +} \ No newline at end of file diff --git a/stable_audio_tools/configs/dataset_configs/s3_wds_example.json b/stable_audio_tools/configs/dataset_configs/s3_wds_example.json new file mode 
100644 index 0000000000000000000000000000000000000000..71e3a8b9ddf3fdf23a2fd23a8d07155328140f74 --- /dev/null +++ b/stable_audio_tools/configs/dataset_configs/s3_wds_example.json @@ -0,0 +1,10 @@ +{ + "dataset_type": "s3", + "datasets": [ + { + "id": "s3-test", + "s3_path": "s3://my-bucket/datasets/webdataset/audio/" + } + ], + "random_crop": true +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..d0f3eba705c12f4088dbe058d16fc50f3f7d3c06 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json @@ -0,0 +1,71 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 1, + "model": { + "encoder": { + "type": "dac", + "config": { + "latent_dim": 64, + "d_model": 128, + "strides": [4, 8, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "latent_dim": 32, + "channels": 1536, + "rates": [8, 8, 8, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 32, + "downsampling_ratio": 2048, + "io_channels": 1 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": false, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128, 64, 32], + "hop_lengths": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json b/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json new file mode 100644 index 0000000000000000000000000000000000000000..e76bd3d9a12ae028f3038562ce8082b8eadca116 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json @@ -0,0 +1,88 @@ +{ + "model_type": "autoencoder", + "sample_size": 32000, + "sample_rate": 32000, + "audio_channels": 1, + "model": { + "encoder": { + "type": "seanet", + "config": { + "channels": 1, + "dimension": 128, + "n_filters": 64, + "ratios": [4, 4, 5, 8], + "n_residual_layers": 1, + "dilation_base": 2, + "lstm": 2, + "norm": "weight_norm" + } + }, + "decoder": { + "type": "seanet", + "config": { + "channels": 1, + "dimension": 128, + "n_filters": 64, + "ratios": [4, 4, 5, 8], + "n_residual_layers": 1, + "dilation_base": 2, + "lstm": 2, + "norm": "weight_norm" + } + }, + "bottleneck": { + "type": "rvq", + "config": { + "num_quantizers": 4, + "codebook_size": 2048, + "dim": 128, + "decay": 0.99, + "threshold_ema_dead_code": 2 + } + }, + "latent_dim": 128, + "downsampling_ratio": 640, + "io_channels": 1 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": true, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 
1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..26dcb25f3322e79422c7ab288aace9f23e711768 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json @@ -0,0 +1,111 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "dac", + "config": { + "in_channels": 2, + "latent_dim": 128, + "d_model": 128, + "strides": [4, 4, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "out_channels": 2, + "latent_dim": 64, + "channels": 1536, + "rates": [8, 8, 4, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 1024, + "io_channels": 2 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1e-4 + } + }, + "scheduler": { + "type": "ExponentialLR", + "config": { + "gamma": 0.999996 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1e-4 + } + }, + "scheduler": { + "type": "ExponentialLR", + "config": { + "gamma": 0.999996 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-6 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..3aa762f2a4bb3ff631fd53401c5ec22e524e9bf2 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, 
+ "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json new file mode 100644 index 0000000000000000000000000000000000000000..a57f9e4abc99157128f505c6f5e5188101808f9b --- /dev/null +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 65536, + "sample_rate": 48000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json new file mode 100644 index 0000000000000000000000000000000000000000..4319a56731f981d2de1a294c2727e087475d1633 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 65536, + "sample_rate": 16000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json new file mode 100644 index 0000000000000000000000000000000000000000..fedb83fa3c741d7c1d4a7215e909862a81730805 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json @@ -0,0 +1,18 @@ +{ + "model_type": 
"diffusion_uncond", + "sample_size": 65536, + "sample_rate": 44100, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 4e-5, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json new file mode 100644 index 0000000000000000000000000000000000000000..f9f96a455ad9e40b4ea624bda4b9c209fea4bcca --- /dev/null +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 131072, + "sample_rate": 48000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json b/stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json new file mode 100644 index 0000000000000000000000000000000000000000..ac1e432948fb49387197b2db5edfd738ce9ca20a --- /dev/null +++ b/stable_audio_tools/configs/model_configs/txt2audio/musicgen_small_finetune.json @@ -0,0 +1,22 @@ +{ + "model_type": "musicgen", + "sample_size": 320000, + "sample_rate": 32000, + "audio_channels": 1, + "model": { + "pretrained": "small" + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_cond": [ + {"prompt": "Keywords: Atmospheres, Orchestral Drone, Bass, Sci-Fi Ambient Soundscape, Synthesiser, Middle Eastern Vocal, dramatic piano"}, + {"prompt": "Genre: Corporate|Instruments: Ukulele, Drums, Clapping, Glockenspiel"}, + {"prompt": "Description: 116 BPM rock drums, drum track for a rock song"}, + {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle."} + ], + "demo_cfg_scales": [3, 6, 9] + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json new file mode 100644 index 0000000000000000000000000000000000000000..22db891d8529f894a26a0c7f7d173ef2ae84b744 --- /dev/null +++ b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json @@ -0,0 +1,107 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 4194304, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "dac", + "config": { + "in_channels": 2, + "latent_dim": 128, + "d_model": 128, + "strides": [4, 4, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "out_channels": 2, + "latent_dim": 64, + "channels": 1536, + "rates": [8, 8, 4, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 1024, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "clap_text", + "config": { + "audio_model_type": "HTSAT-base", + "enable_fusion": true, + "clap_ckpt_path": "/path/to/clap.ckpt", + "use_text_features": true, + "feature_layer_ix": -2 + } + }, + { + "id": "seconds_start", + "type": "int", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": 
"seconds_total", + "type": "int", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "type": "adp_cfg_1d", + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "config": { + "in_channels": 64, + "context_embedding_features": 768, + "context_embedding_max_length": 79, + "channels": 256, + "resnet_groups": 16, + "kernel_multiplier_downsample": 2, + "multipliers": [4, 4, 4, 5, 5], + "factors": [1, 2, 2, 4], + "num_blocks": [2, 2, 2, 2], + "attentions": [1, 3, 3, 3, 3], + "attention_heads": 16, + "attention_multiplier": 4, + "use_nearest_upsample": false, + "use_skip_scale": true, + "use_context_time": true + } + }, + "io_channels": 64 + }, + "training": { + "learning_rate": 4e-5, + "demo": { + "demo_every": 2000, + "demo_steps": 250, + "num_demos": 4, + "demo_cond": [ + {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95}, + {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90}, + {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, + {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60} + ], + "demo_cfg_scales": [3, 6, 9] + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json new file mode 100644 index 0000000000000000000000000000000000000000..bf8d5742ec183452c683db763753995eb347dfbc --- /dev/null +++ b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json @@ -0,0 +1,127 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 12582912, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "clap_text", + "config": { + "audio_model_type": "HTSAT-base", + "enable_fusion": true, + "clap_ckpt_path": "/path/to/clap.ckpt", + "use_text_features": true, + "feature_layer_ix": -2 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + 
"training": { + "use_ema": true, + "log_loss_info": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 250, + "num_demos": 4, + "demo_cond": [ + {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80}, + {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250}, + {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, + {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 190} + ], + "demo_cfg_scales": [3, 6, 9] + } + } +} \ No newline at end of file diff --git a/stable_audio_tools/data/__init__.py b/stable_audio_tools/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0236e62df7c5ce84da96674cb7b7ccb921d070f4 --- /dev/null +++ b/stable_audio_tools/data/dataset.py @@ -0,0 +1,597 @@ +import importlib +import numpy as np +import io +import os +import posixpath +import random +import re +import subprocess +import time +import torch +import torchaudio +import webdataset as wds + +from aeiou.core import is_silence +from os import path +from pedalboard.io import AudioFile +from torchaudio import transforms as T +from typing import Optional, Callable, List + +from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T + +AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") + +# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def fast_scandir( + dir:str, # top-level directory at which to begin scanning + ext:list, # list of allowed file extensions, + #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB + ): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + file_ext = os.path.splitext(f.name)[1].lower() + is_hidden = os.path.basename(f.path).startswith(".") + + if file_ext in ext and not is_hidden: + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = fast_scandir(dir, ext) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def keyword_scandir( + dir: str, # top-level directory at which to begin scanning + ext: list, # list of allowed file extensions + keywords: list, # list of keywords to search for in the file name +): + "very fast `glob` alternative. 
from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + # make keywords case insensitive + keywords = [keyword.lower() for keyword in keywords] + # add starting period to extensions if needed + ext = ['.'+x if x[0] != '.' else x for x in ext] + banned_words = ["paxheader", "__macosx"] + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + is_hidden = f.name.split("/")[-1][0] == '.' + has_ext = os.path.splitext(f.name)[1].lower() in ext + name_lower = f.name.lower() + has_keyword = any( + [keyword in name_lower for keyword in keywords]) + has_banned = any( + [banned_word in name_lower for banned_word in banned_words]) + if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = keyword_scandir(dir, ext, keywords) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def get_audio_filenames( + paths: list, # directories in which to search + keywords=None, + exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] +): + "recursively get a list of audio filenames" + filenames = [] + if type(paths) is str: + paths = [paths] + for path in paths: # get a list of relevant filenames + if keywords is not None: + subfolders, files = keyword_scandir(path, exts, keywords) + else: + subfolders, files = fast_scandir(path, exts) + filenames.extend(files) + return filenames + +class SampleDataset(torch.utils.data.Dataset): + def __init__( + self, + paths, + sample_size=65536, + sample_rate=48000, + keywords=None, + relpath=None, + random_crop=True, + force_channels="stereo", + custom_metadata_fn: Optional[Callable[[str], str]] = None + ): + super().__init__() + self.filenames = [] + self.relpath = relpath + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + + self.force_channels = force_channels + + self.encoding = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + + self.filenames = get_audio_filenames(paths, keywords) + + print(f'Found {len(self.filenames)} files') + + self.sr = sample_rate + + self.custom_metadata_fn = custom_metadata_fn + + def load_file(self, filename): + ext = filename.split(".")[-1] + + if ext == "mp3": + with AudioFile(filename) as f: + audio = f.read(f.frames) + audio = torch.from_numpy(audio) + in_sr = f.samplerate + else: + audio, in_sr = torchaudio.load(filename, format=ext) + + if in_sr != self.sr: + resample_tf = T.Resample(in_sr, self.sr) + audio = resample_tf(audio) + + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + try: + start_time = time.time() + audio = self.load_file(audio_filename) + + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio) + + # Run augmentations on this sample (including random crop) + if self.augs is not None: + audio = self.augs(audio) + + audio = audio.clamp(-1, 1) + + # Encode the file to assist in prediction + if self.encoding is not None: + audio = self.encoding(audio) + + info = {} + + info["path"] = audio_filename + + if self.relpath is not None: + 
info["relpath"] = path.relpath(audio_filename, self.relpath) + + info["timestamps"] = (t_start, t_end) + info["seconds_start"] = seconds_start + info["seconds_total"] = seconds_total + info["padding_mask"] = padding_mask + + end_time = time.time() + + info["load_time"] = end_time - start_time + + if self.custom_metadata_fn is not None: + custom_metadata = self.custom_metadata_fn(info, audio) + info.update(custom_metadata) + + if "__reject__" in info and info["__reject__"]: + return self[random.randrange(len(self))] + + return (audio, info) + except Exception as e: + print(f'Couldn\'t load file {audio_filename}: {e}') + return self[random.randrange(len(self))] + +def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if wds.tariterators.trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if wds.tariterators.valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}") + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if wds.tariterators.valid_sample(current_sample): + yield current_sample + +wds.tariterators.group_by_keys = group_by_keys + +# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): + """ + Returns a list of full S3 paths to files in a given S3 bucket and directory path. 
+ """ + # Ensure dataset_path ends with a trailing slash + if dataset_path != '' and not dataset_path.endswith('/'): + dataset_path += '/' + # Use posixpath to construct the S3 URL path + bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + # Construct the `aws s3 ls` command + cmd = ['aws', 's3', 'ls', bucket_path] + + if profile is not None: + cmd.extend(['--profile', profile]) + + if recursive: + # Add the --recursive flag if requested + cmd.append('--recursive') + + # Run the `aws s3 ls` command and capture the output + run_ls = subprocess.run(cmd, capture_output=True, check=True) + # Split the output into lines and strip whitespace from each line + contents = run_ls.stdout.decode('utf-8').split('\n') + contents = [x.strip() for x in contents if x] + # Remove the timestamp from lines that begin with a timestamp + contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) + if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] + # Construct a full S3 path for each file in the contents list + contents = [posixpath.join(s3_url_prefix or '', x) + for x in contents if not x.endswith('/')] + # Apply the filter, if specified + if filter: + contents = [x for x in contents if filter in x] + # Remove redundant directory names in the S3 URL + if recursive: + # Get the main directory name from the S3 URL + main_dir = "/".join(bucket_path.split('/')[3:]) + # Remove the redundant directory names from each file path + contents = [x.replace(f'{main_dir}', '').replace( + '//', '/') for x in contents] + # Print debugging information, if requested + if debug: + print("contents = \n", contents) + # Return the list of S3 paths to files + return contents + + +def get_all_s3_urls( + names=[], # list of all valid [LAION AudioDataset] dataset names + # list of subsets you want from those datasets, e.g. ['train','valid'] + subsets=[''], + s3_url_prefix=None, # prefix for those dataset names + recursive=True, # recursively list all tar files in all subdirs + filter_str='tar', # only grab files with this substring + # print debugging info -- note: info displayed likely to change at dev's whims + debug=False, + profiles={}, # dictionary of profiles for each item in names, e.g. 
{'dataset1': 'profile1', 'dataset2': 'profile2'} +): + "get urls of shards (tar files) for multiple datasets in one s3 bucket" + urls = [] + for name in names: + # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list + if s3_url_prefix is None: + contents_str = name + else: + # Construct the S3 path using the s3_url_prefix and the current name value + contents_str = posixpath.join(s3_url_prefix, name) + if debug: + print(f"get_all_s3_urls: {contents_str}:") + for subset in subsets: + subset_str = posixpath.join(contents_str, subset) + if debug: + print(f"subset_str = {subset_str}") + # Get the list of tar files in the current subset directory + profile = profiles.get(name, None) + tar_list = get_s3_contents( + subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) + for tar in tar_list: + # Escape spaces and parentheses in the tar filename for use in the shell command + tar = tar.replace(" ", "\\ ").replace( + "(", "\\(").replace(")", "\\)") + # Construct the S3 path to the current tar file + s3_path = posixpath.join(name, subset, tar) + " -" + # Construct the AWS CLI command to download the current tar file + if s3_url_prefix is None: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" + else: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" + if profiles.get(name): + request_str += f" --profile {profiles.get(name)}" + if debug: + print("request_str = ", request_str) + # Add the constructed URL to the list of URLs + urls.append(request_str) + return urls + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + print(f"Handling webdataset error ({repr(exn)}). 
Ignoring.") + return True + + +def is_valid_sample(sample): + has_json = "json" in sample + has_audio = "audio" in sample + is_silent = is_silence(sample["audio"]) + is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"] + + return has_json and has_audio and not is_silent and not is_rejected + +class S3DatasetConfig: + def __init__( + self, + id: str, + s3_path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.s3_path = s3_path + self.custom_metadata_fn = custom_metadata_fn + self.profile = profile + self.urls = [] + + def load_data_urls(self): + self.urls = get_all_s3_urls( + names=[self.s3_path], + s3_url_prefix=None, + recursive=True, + profiles={self.s3_path: self.profile} if self.profile else {}, + ) + + return self.urls + +def audio_decoder(key, value): + # Get file extension from key + ext = key.split(".")[-1] + + if ext in AUDIO_KEYS: + return torchaudio.load(io.BytesIO(value)) + else: + return None + +def collation_fn(samples): + batched = list(zip(*samples)) + result = [] + for b in batched: + if isinstance(b[0], (int, float)): + b = np.array(b) + elif isinstance(b[0], torch.Tensor): + b = torch.stack(b) + elif isinstance(b[0], np.ndarray): + b = np.array(b) + else: + b = b + result.append(b) + return result + +class S3WebDataLoader(): + def __init__( + self, + datasets: List[S3DatasetConfig], + batch_size, + sample_size, + sample_rate=48000, + num_workers=8, + epoch_steps=1000, + random_crop=True, + force_channels="stereo", + augment_phase=True, + **data_loader_kwargs + ): + + self.datasets = datasets + + self.sample_size = sample_size + self.sample_rate = sample_rate + self.random_crop = random_crop + self.force_channels = force_channels + self.augment_phase = augment_phase + + urls = [dataset.load_data_urls() for dataset in datasets] + + # Flatten the list of lists of URLs + urls = [url for dataset_urls in urls for url in dataset_urls] + + self.dataset = wds.DataPipeline( + wds.ResampledShards(urls), + wds.tarfile_to_samples(handler=log_and_continue), + wds.decode(audio_decoder, handler=log_and_continue), + wds.map(self.wds_preprocess, handler=log_and_continue), + wds.select(is_valid_sample), + wds.to_tuple("audio", "json", handler=log_and_continue), + wds.batched(batch_size, partial=False, collation_fn=collation_fn), + ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) + + self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) + + def wds_preprocess(self, sample): + + found_key, rewrite_key = '', '' + for k, v in sample.items(): # print the all entries in dict + for akey in AUDIO_KEYS: + if k.endswith(akey): + # to rename long/weird key with its simpler counterpart + found_key, rewrite_key = k, akey + break + if '' != found_key: + break + if '' == found_key: # got no audio! 
+ return None # try returning None to tell WebDataset to skip this one + + audio, in_sr = sample[found_key] + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate) + audio = resample_tf(audio) + + if self.sample_size is not None: + # Pad/crop and get the relative timestamp + pad_crop = PadCrop_Normalized_T( + self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( + audio) + sample["json"]["seconds_start"] = seconds_start + sample["json"]["seconds_total"] = seconds_total + sample["json"]["padding_mask"] = padding_mask + else: + t_start, t_end = 0, 1 + + # Check if audio is length zero, initialize to a single zero if so + if audio.shape[-1] == 0: + audio = torch.zeros(1, 1) + + # Make the audio stereo and augment by randomly inverting phase + augs = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + PhaseFlipper() if self.augment_phase else torch.nn.Identity() + ) + + audio = augs(audio) + + sample["json"]["timestamps"] = (t_start, t_end) + + if "text" in sample["json"]: + sample["json"]["prompt"] = sample["json"]["text"] + + # Check for custom metadata functions + for dataset in self.datasets: + if dataset.custom_metadata_fn is None: + continue + + if dataset.s3_path in sample["__url__"]: + custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) + sample["json"].update(custom_metadata) + + if found_key != rewrite_key: # rename long/weird key with its simpler counterpart + del sample[found_key] + + sample["audio"] = audio + + # Add audio to the metadata as well for conditioning + sample["json"]["audio"] = audio + + return sample + +def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): + + dataset_type = dataset_config.get("dataset_type", None) + + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + else: + force_channels = "stereo" + + if dataset_type == "audio_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + training_dirs = [] + + custom_metadata_fn = None + custom_metadata_module_path = dataset_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + training_dirs.append(audio_dir_path) + + train_set = SampleDataset( + training_dirs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + force_channels=force_channels, + custom_metadata_fn=custom_metadata_fn, + relpath=training_dirs[0] #TODO: Make relpath relative to each training dir + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + 
+ elif dataset_type == "s3": + dataset_configs = [] + + for s3_config in dataset_config["datasets"]: + + custom_metadata_fn = None + custom_metadata_module_path = s3_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + dataset_configs.append( + S3DatasetConfig( + id=s3_config["id"], + s3_path=s3_config["s3_path"], + custom_metadata_fn=custom_metadata_fn, + profile=s3_config.get("profile", None), + ) + ) + + return S3WebDataLoader( + dataset_configs, + sample_rate=sample_rate, + sample_size=sample_size, + batch_size=batch_size, + random_crop=dataset_config.get("random_crop", True), + num_workers=num_workers, + persistent_workers=True, + force_channels=force_channels, + epoch_steps=dataset_config.get("epoch_steps", 2000), + ).data_loader \ No newline at end of file diff --git a/stable_audio_tools/data/utils.py b/stable_audio_tools/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..848012e46d057129d8383d9d70c80c89e50c77a4 --- /dev/null +++ b/stable_audio_tools/data/utils.py @@ -0,0 +1,96 @@ +import math +import random +import torch + +from torch import nn +from typing import Tuple + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + +class PadCrop_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + + n_channels, n_samples = source.shape + + # If the audio is shorter than the desired length, pad it + upper_bound = max(0, n_samples - self.n_samples) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + + # Create the chunk + chunk = source.new_zeros([n_channels, self.n_samples]) + + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] + + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PhaseFlipper(nn.Module): + "Randomly invert the phase of a signal" + def __init__(self, p=0.5): + super().__init__() + self.p = p + 
def __call__(self, signal): + return -signal if (random.random() < self.p) else signal + +class Mono(nn.Module): + def __call__(self, signal): + return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + +class Stereo(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> 2, s + signal = signal.unsqueeze(0).repeat(2, 1) + elif len(signal_shape) == 2: + if signal_shape[0] == 1: #1, s -> 2, s + signal = signal.repeat(2, 1) + elif signal_shape[0] > 2: #?, s -> 2,s + signal = signal[:2, :] + + return signal diff --git a/stable_audio_tools/inference/__init__.py b/stable_audio_tools/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stable_audio_tools/inference/generation.py b/stable_audio_tools/inference/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..fcac33dbe0870a2b846334cd5d51e2e393c2407e --- /dev/null +++ b/stable_audio_tools/inference/generation.py @@ -0,0 +1,243 @@ +import numpy as np +import torch +import typing as tp +import math +from torchaudio import transforms as T + +from .utils import prepare_audio +from .sampling import sample, sample_k +from ..data.utils import PadCrop + +def generate_diffusion_uncond( + model, + steps: int = 250, + batch_size: int = 1, + sample_size: int = 2097152, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + + # Inpainting mask + + if init_audio is not None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + # Now the generative AI part: + # k-diffusion denoising process go! 
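+ # sample_k (imported above from .sampling) runs the k-diffusion sampler: it iteratively denoises `noise` over `steps` steps in latent space; when init_audio is provided, the sigma_max set above controls how far the result departs from it, and mask stays None in this unconditional path.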
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) + + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + + +def generate_diffusion_cond( + model, + steps: int = 250, + cfg_scale=6, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + negative_conditioning: dict = None, + negative_conditioning_tensors: tp.Optional[dict] = None, + batch_size: int = 1, + sample_size: int = 2097152, + sample_rate: int = 48000, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + mask_args: dict = None, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + """ + Generate audio from a prompt using a diffusion model. + + Args: + model: The diffusion model to use for generation. + steps: The number of diffusion steps to use. + cfg_scale: Classifier-free guidance scale + conditioning: A dictionary of conditioning parameters to use for generation. + conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. + batch_size: The batch size to use for generation. + sample_size: The length of the audio to generate, in samples. + sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) + seed: The random seed to use for generation, or -1 to use a random seed. + device: The device to use for generation. + init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. + init_noise_level: The noise level to use when generating from an initial audio sample. + return_latents: Whether to return the latents used for generation instead of the decoded audio. + **sampler_kwargs: Additional keyword arguments to pass to the sampler. + """ + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + # Conditioning + assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" + if conditioning_tensors is None: + conditioning_tensors = model.conditioner(conditioning, device) + conditioning_tensors = model.get_conditioning_inputs(conditioning_tensors) + + if negative_conditioning is not None or negative_conditioning_tensors is not None: + + if negative_conditioning_tensors is None: + negative_conditioning_tensors = model.conditioner(negative_conditioning, device) + + negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) + else: + negative_conditioning_tensors = {} + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. 
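# The `conditioning` argument consumed above is one dict per batch item, keyed by the
# conditioner ids defined in the model config; a minimal sketch with illustrative values.
conditioning = [{
    "latitude": -27.47, "longitude": 153.02,
    "temperature": 24.0, "humidity": 65, "wind_speed": 3.0, "pressure": 1012,
    "minutes_of_day": 380, "day_of_year": 120,
    "seconds_start": 0, "seconds_total": 22,
}]  # batch_size == 1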
+ in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + mask_args = None + + # Inpainting mask + if init_audio is not None and mask_args is not None: + # Cut and paste init_audio according to cropfrom, pastefrom, pasteto + # This is helpful for forward and reverse outpainting + cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) + pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) + pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) + assert pastefrom < pasteto, "Paste From should be less than Paste To" + croplen = pasteto - pastefrom + if cropfrom + croplen > sample_size: + croplen = sample_size - cropfrom + cropto = cropfrom + croplen + pasteto = pastefrom + croplen + cutpaste = init_audio.new_zeros(init_audio.shape) + cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] + #print(cropfrom, cropto, pastefrom, pasteto) + init_audio = cutpaste + # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args + mask = build_mask(sample_size, mask_args) + mask = mask.to(device) + elif init_audio is not None and mask_args is None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + # Now the generative AI part: + # k-diffusion denoising process go! + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_tensors, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + + # v-diffusion: + #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale) + + # Denoising process done. 
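# A quick sketch of how the percentage-based mask_args above map onto latent indices;
# the latent length and percentages are illustrative.
import math
sample_size = 1024                                                    # latent frames
mask_args = {"cropfrom": 0.0, "pastefrom": 50.0, "pasteto": 100.0}
cropfrom = math.floor(mask_args["cropfrom"] / 100.0 * sample_size)    # 0
pastefrom = math.floor(mask_args["pastefrom"] / 100.0 * sample_size)  # 512
pasteto = math.ceil(mask_args["pasteto"] / 100.0 * sample_size)       # 1024
croplen = pasteto - pastefrom
print(cropfrom, pastefrom, pasteto, croplen)  # 0 512 1024 512: the first half of the
                                              # init latents is pasted into the second half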
+ # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + #cast sampled latents to pretransform dtype + sampled = sampled.to(next(model.pretransform.parameters()).dtype) + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + +# builds a softmask given the parameters +# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, +# and anything between is a mixture of old/new +# ideally 0.5 is half/half mixture but i haven't figured this out yet +def build_mask(sample_size, mask_args): + maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) + maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) + softnessL = round(mask_args["softnessL"]/100.0 * sample_size) + softnessR = round(mask_args["softnessR"]/100.0 * sample_size) + marination = mask_args["marination"] + # use hann windows for softening the transition (i don't know if this is correct) + hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] + hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] + # build the mask. + mask = torch.zeros((sample_size)) + mask[maskstart:maskend] = 1 + mask[maskstart:maskstart+softnessL] = hannL + mask[maskend-softnessR:maskend] = hannR + # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds + if marination > 0: + mask = mask * (1-marination) + #print(mask) + return mask \ No newline at end of file diff --git a/stable_audio_tools/inference/sampling.py b/stable_audio_tools/inference/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..1c6efcd3770416c7e43a672b805dbd47f646fc4e --- /dev/null +++ b/stable_audio_tools/inference/sampling.py @@ -0,0 +1,170 @@ +import torch +import math +from tqdm import trange + +import k_diffusion as K + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + +def t_to_alpha_sigma(t): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +@torch.no_grad() +def sample(model, x, steps, eta, **extra_args): + """Draws samples from a model given starting noise. v-diffusion""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + + alphas, sigmas = get_alphas_sigmas(t) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], **extra_args).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. 
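# A small check of the trigonometric schedule defined above: for every t in [0, 1],
# alpha(t)^2 + sigma(t)^2 == 1, which is what lets v-prediction recover both the
# denoised signal (pred) and the noise (eps) from a single model output.
import math
import torch
t = torch.linspace(1, 0, 5)
alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
print(torch.allclose(alphas ** 2 + sigmas ** 2, torch.ones_like(t)))  # True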
+ if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred + +# Soft mask inpainting is just shrinking hard (binary) mask inpainting +# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step +def get_bmask(i, steps, mask): + strength = (i+1)/(steps) + # convert to binary mask + bmask = torch.where(mask<=strength,1,0) + return bmask + +def make_cond_model_fn(model, cond_fn): + def cond_model_fn(x, sigma, **kwargs): + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) + return cond_denoised + return cond_model_fn + +# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_k( + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + denoiser = K.external.VDenoiser(model_fn) + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + # Make the list of sigmas. 
Sigma values are scalars related to the amount of noise each denoising step has + sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) + # Scale the initial noise by sigma + noise = noise * sigmas[0] + + wrapped_callback = callback + + if mask is None and init_data is not None: + # VARIATION (no inpainting) + # set the initial latent to the init_data, and noise it with initial sigma + x = init_data + noise + elif mask is not None and init_data is not None: + # INPAINTING + bmask = get_bmask(0, steps, mask) + # initial noising + input_noised = init_data + noise + # set the initial latent to a mix of init_data and noise, based on step 0's binary mask + x = input_noised * bmask + noise * (1-bmask) + # define the inpainting callback function (Note: side effects, it mutates x) + # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` + def inpainting_callback(args): + i = args["i"] + x = args["x"] + sigma = args["sigma"] + #denoised = args["denoised"] + # noise the init_data input with this step's appropriate amount of noise + input_noised = init_data + torch.randn_like(init_data) * sigma + # shrinking hard mask + bmask = get_bmask(i, steps, mask) + # mix input_noise with x, using binary mask + new_x = input_noised * bmask + x * (1-bmask) + # mutate x + x[:,:,:] = new_x[:,:,:] + # wrap together the inpainting callback and the user-submitted callback. + if callback is None: + wrapped_callback = inpainting_callback + else: + wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) + else: + # SAMPLING + # set the initial latent to noise + x = noise + + + with torch.cuda.amp.autocast(): + if sampler_type == "k-heun": + return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-lms": + return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpmpp-2s-ancestral": + return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-2": + return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-fast": + return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-adaptive": + return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + diff --git a/stable_audio_tools/inference/utils.py b/stable_audio_tools/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6c0a57609f68156ad244da9b5819666329772e --- /dev/null +++ b/stable_audio_tools/inference/utils.py @@ -0,0 +1,35 @@ +from ..data.utils import PadCrop 
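# A short sketch of how get_bmask in sampling.py above hardens the soft inpainting mask
# over the run: positions whose mask value is at or below (i + 1) / steps get a 1 and are
# overwritten with the noised init audio in the inpainting callback, while 0 positions are
# left to the sampler. The mask values below are illustrative.
import torch
mask = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
steps = 4
for i in range(steps):
    bmask = torch.where(mask <= (i + 1) / steps, 1, 0)
    print(i, bmask.tolist())
# 0 [1, 1, 0, 0, 0]
# 1 [1, 1, 1, 0, 0]
# 2 [1, 1, 1, 1, 0]
# 3 [1, 1, 1, 1, 1]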
+ +from torchaudio import transforms as T + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/stable_audio_tools/interface/__init__.py b/stable_audio_tools/interface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4e216536e4f68a6cbca169f61cbd77f3a2e9de --- /dev/null +++ b/stable_audio_tools/interface/gradio.py @@ -0,0 +1,788 @@ +import gc +import numpy as np +import gradio as gr +import json +import torch +import torchaudio + +from aeiou.viz import audio_spectrogram_image +from einops import rearrange +from safetensors.torch import load_file +from torch.nn import functional as F +from torchaudio import transforms as T + +from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..models.factory import create_model_from_config +from ..models.pretrained import get_pretrained_model +from ..models.utils import load_ckpt_state_dict +from ..inference.utils import prepare_audio +from ..training.utils import copy_state_dict + +# Define preset values +presets = { + "Pied Currawong": { + "latitude": -33.6467, + "longitude": 150.3246, + "temperature": 12.43, + "humidity": 86, + "wind_speed": 0.66, + "pressure": 1013, + "minutes_of_day": 369, + "day_of_year": 297, + }, + "Yellow-tailed Black Cockatoo": { + "latitude": -32.8334, + "longitude": 150.2001, + "temperature": 23.23, + "humidity": 45, + "wind_speed": 1.37, + "pressure": 1009, + "minutes_of_day": 986, + "day_of_year": 78, + }, + "Australian Magpie": { + "latitude": -38.522, + "longitude": 145.3365, + "temperature": 18.75, + "humidity": 67, + "wind_speed": 1.5, + "pressure": 1023, + "minutes_of_day": 940, + "day_of_year": 307, + }, + "Laughing Kookaburra": { + "latitude": -27.2685099, + "longitude": 152.8587437, + "temperature": 9.02, + "humidity": 94, + "wind_speed": 1.5, + "pressure": 1025, + "minutes_of_day": 320, + "day_of_year": 236, + } +} + +def update_sliders(preset_name): + preset = presets[preset_name] + return (preset["latitude"], preset["longitude"], preset["temperature"], preset["humidity"], preset["wind_speed"], preset["pressure"], preset["minutes_of_day"], preset["day_of_year"]) + + +model = None +sample_rate = 44100 +sample_size = 524288 + + +def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + global model, sample_rate, sample_size + + if pretrained_name is not None: + print(f"Loading pretrained model {pretrained_name}") + model, model_config = get_pretrained_model(pretrained_name) 
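# A minimal sketch of what prepare_audio in inference/utils.py above does to a raw clip:
# resample, pad/crop to the target length, add a batch dimension, and match the channel
# count. Assumes the package (and torchaudio) is installed; shapes and rates are illustrative.
import torch
from stable_audio_tools.inference.utils import prepare_audio

clip = torch.randn(2, 48000)  # 1 s of stereo at 48 kHz
prepared = prepare_audio(clip, in_sr=48000, target_sr=44100, target_length=44100,
                         target_channels=1, device="cpu")
print(prepared.shape)  # torch.Size([1, 1, 44100]) -> [batch, channels, samples]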
+ + elif model_config is not None and model_ckpt_path is not None: + print(f"Creating model from config") + model = create_model_from_config(model_config) + + print(f"Loading model checkpoint from {model_ckpt_path}") + # Load checkpoint + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + + if pretransform_ckpt_path is not None: + print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") + model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) + print(f"Done loading pretransform") + + model.to(device).eval().requires_grad_(False) + + if model_half: + model.to(torch.float16) + + print(f"Done loading model") + + return model, model_config + +def generate_cond( + seconds_start=0, + seconds_total=30, + latitude = 0.0, + longitude = 0.0, + temperature = 0.0, + humidity = 0.0, + wind_speed = 0.0, + pressure = 0.0, + minutes_of_day = 0.0, + day_of_year = 0.0, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-2m-sde", + sigma_min=0.03, + sigma_max=50, + cfg_rescale=0.4, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1 + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + + global preview_images + preview_images = [] + if preview_every == 0: + preview_every = None + + # Return fake stereo audio + conditioning = [{"latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + denoised = rearrange(denoised, "b d n -> d (b n)") + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + # If inpainting, send mask args 
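# The input_sample_size arithmetic above rounds the init audio length up to the next
# multiple of model.min_input_length, leaving exact multiples unchanged; a quick check
# with an illustrative block size.
min_input_length = 2048
for audio_length in (4096, 5000):
    rounded = audio_length + (min_input_length - (audio_length % min_input_length)) % min_input_length
    print(audio_length, "->", rounded)  # 4096 -> 4096, 5000 -> 6144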
+ # This will definitely change in the future + if mask_cropfrom is not None: + mask_args = { + "cropfrom": mask_cropfrom, + "pastefrom": mask_pastefrom, + "pasteto": mask_pasteto, + "maskstart": mask_maskstart, + "maskend": mask_maskend, + "softnessL": mask_softnessL, + "softnessR": mask_softnessR, + "marination": mask_marination, + } + else: + mask_args = None + + # Do the audio generation + audio = generate_diffusion_cond( + model, + conditioning=conditioning, + steps=steps, + cfg_scale=cfg_scale, + batch_size=batch_size, + sample_size=input_sample_size, + sample_rate=sample_rate, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + mask_args = mask_args, + callback = progress_callback if preview_every is not None else None, + scale_phi = cfg_rescale + ) + + # Convert to WAV file + audio = rearrange(audio, "b d n -> d (b n)") + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save("output.wav", audio, sample_rate) + + # Let's look at a nice spectrogram too + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_uncond( + steps=250, + seed=-1, + sampler_type="dpmpp-2m-sde", + sigma_min=0.03, + sigma_max=50, + use_init=False, + init_audio=None, + init_noise_level=1.0, + batch_size=1, + preview_every=None + ): + + global preview_images + + preview_images = [] + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + + denoised = rearrange(denoised, "b d n -> d (b n)") + + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + audio = generate_diffusion_uncond( + model, + steps=steps, + batch_size=batch_size, + sample_size=input_sample_size, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + callback = progress_callback if preview_every is not None else None + ) + + audio = 
rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_lm( + temperature=1.0, + top_p=0.95, + top_k=0, + batch_size=1, + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + audio = model.generate_audio( + batch_size=batch_size, + max_gen_len = sample_size//model.pretransform.downsampling_ratio, + conditioning=None, + temp=temperature, + top_p=top_p, + top_k=top_k, + use_cache=True + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram]) + + +def create_uncond_sampling_ui(model_config): + generate_button = gr.Button("Generate", variant='primary', scale=1) + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max") + + with gr.Accordion("Init audio", open=False): + init_audio_checkbox = gr.Checkbox(label="Use init audio") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + + with gr.Column(): + audio_output = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + send_to_init_button = gr.Button("Send to init audio", scale=1) + send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) + + generate_button.click(fn=generate_uncond, + inputs=[ + steps_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider, + ], + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") +def create_conditioning_slider(min_val, max_val, default_value, label): + """ + Create a Gradio slider for a given conditioning parameter. + + Args: + - min_val: The minimum value for the slider. + - max_val: The maximum value for the slider. + - label: The label for the slider, which is displayed in the UI. + + Returns: + - A gr.Slider object configured according to the provided parameters. 
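    Note:
        default_value sets the slider's initial position, and the step size is
        derived as (max_val - min_val) / 1000.

    Example (illustrative values):
        create_conditioning_slider(min_val=-10.0, max_val=55.0,
                                   default_value=22.0, label="temperature")
        # -> gr.Slider(minimum=-10.0, maximum=55.0, step=0.065, value=22.0, label="temperature")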
+ """ + step = (max_val - min_val) / 1000 + default_val = default_value + print(f"Creating slider for {label} with min_val={min_val}, max_val={max_val}, step={step}, default_val={default_val}") + return gr.Slider(minimum=min_val, maximum=max_val, step=step, value=default_val, label=label) + +def create_sampling_ui(model_config): + with gr.Row(): + + generate_button = gr.Button("Generate", variant='primary', scale=1) + + model_conditioning_config = model_config["model"].get("conditioning", None) + + has_seconds_start = False + has_seconds_total = False + + if model_conditioning_config is not None: + for conditioning_config in model_conditioning_config["configs"]: + if conditioning_config["id"] == "seconds_start": + has_seconds_start = True + if conditioning_config["id"] == "seconds_total": + has_seconds_total = True + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + + seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) + + seconds_total_slider = gr.Slider(minimum=0, maximum=22, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) + + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="Steps") + + # Preview Every slider + preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") + + # CFG scale + cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale") + + with gr.Accordion("Climate and location", open=True): + preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label="Select Preset") + + latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None) + if latitude_config: + latitude_slider = create_conditioning_slider( + min_val=latitude_config["config"]["min_val"], + max_val=latitude_config["config"]["max_val"], + default_value = -29.8913, + label="latitude") + + longitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "longitude"), None) + if longitude_config: + longitude_slider = create_conditioning_slider( + min_val=longitude_config["config"]["min_val"], + max_val=longitude_config["config"]["max_val"], + default_value=152.4951, + label="longitude") + + temperature_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "temperature"), None) + if temperature_config: + temperature_slider = create_conditioning_slider( + min_val=temperature_config["config"]["min_val"], + max_val=temperature_config["config"]["max_val"], + default_value=22.05, + label="temperature") + + humidity_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "humidity"), None) + if humidity_config: + humidity_slider = create_conditioning_slider( + min_val=humidity_config["config"]["min_val"], + max_val=humidity_config["config"]["max_val"], + default_value=88, + label="humidity") + + wind_speed_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "wind_speed"), None) + if wind_speed_config: + wind_speed_slider = create_conditioning_slider( + min_val=wind_speed_config["config"]["min_val"], + max_val=wind_speed_config["config"]["max_val"], + default_value=0.54, + label="wind_speed") + + pressure_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "pressure"), None) + if pressure_config: + pressure_slider = create_conditioning_slider( + 
min_val=pressure_config["config"]["min_val"], + max_val=pressure_config["config"]["max_val"], + default_value=1021, + label="pressure") + + minutes_of_day_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "minutes_of_day"), None) + if minutes_of_day_config: + minutes_of_day_slider = create_conditioning_slider( + min_val=minutes_of_day_config["config"]["min_val"], + max_val=minutes_of_day_config["config"]["max_val"], + default_value=1354, + label="minutes_of_day") + + day_of_year_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "day_of_year"), None) + if day_of_year_config: + day_of_year_slider = create_conditioning_slider( + min_val=day_of_year_config["config"]["min_val"], + max_val=day_of_year_config["config"]["max_val"], + default_value=342, + label="Day of year") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=50, label="Sigma max") + cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.4, label="CFG rescale amount") + + + # Default generation tab + with gr.Accordion("Init audio", open=False): + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=1.0, label="Init noise level") + + inputs = [ + seconds_start_slider, + seconds_total_slider, + latitude_slider, + longitude_slider, + temperature_slider, + humidity_slider, + wind_speed_slider, + pressure_slider, + minutes_of_day_slider, + day_of_year_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_noise_level_slider + ] + + with gr.Column(): + audio_output = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + + generate_button.click(fn=generate_cond, + inputs=inputs, + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") + + preset_dropdown.change( + fn=update_sliders, + inputs=[preset_dropdown], + outputs=[ + latitude_slider, + longitude_slider, + temperature_slider, + humidity_slider, + wind_speed_slider, + pressure_slider, + minutes_of_day_slider, + day_of_year_slider + ] + ) + +def create_txt2audio_ui(model_config): + with gr.Blocks() as ui: + with gr.Tab("Generation"): + create_sampling_ui(model_config) + # with gr.Tab("Inpainting"): + # create_sampling_ui(model_config, inpainting=True) + return ui + +def create_diffusion_uncond_ui(model_config): + with gr.Blocks() as ui: + create_uncond_sampling_ui(model_config) + + return ui + +def autoencoder_process(audio, latent_noise, n_quantizers): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) + else: + audio = audio.transpose(0, 1) + + audio = 
model.preprocess_audio_for_encoder(audio, in_sr) + # Note: If you need to do chunked encoding, to reduce VRAM, + # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128 + # To turn it off, do chunked=False + # Optimal overlap and chunk_size values will depend on the model. + # See encode_audio & decode_audio in autoencoders.py for more info + # Get dtype of model + dtype = next(model.parameters()).dtype + + audio = audio.to(dtype) + + if n_quantizers > 0: + latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers) + else: + latents = model.encode_audio(audio, chunked=False) + + if latent_noise > 0: + latents = latents + torch.randn_like(latents) * latent_noise + + audio = model.decode_audio(latents, chunked=False) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_autoencoder_ui(model_config): + + is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"] + + if is_dac_rvq: + n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] + else: + n_quantizers = 0 + + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq) + latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise") + process_button = gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process") + + return ui + +def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) # [1, n] + elif audio.dim() == 2: + audio = audio.transpose(0, 1) # [n, 2] -> [2, n] + + audio = audio.unsqueeze(0) + + audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_diffusion_prior_ui(model_config): + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + # Sampler params + with gr.Row(): + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max") + process_button = 
gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process") + + return ui + +def create_lm_ui(model_config): + with gr.Blocks() as ui: + output_audio = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + + # Sampling params + with gr.Row(): + temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature") + top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p") + top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k") + + generate_button = gr.Button("Generate", variant='primary', scale=1) + generate_button.click( + fn=generate_lm, + inputs=[ + temperature_slider, + top_p_slider, + top_k_slider + ], + outputs=[output_audio, audio_spectrogram_output], + api_name="generate" + ) + + return ui + +def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): + + assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" + + if model_config_path is not None: + # Load config from json file + with open(model_config_path) as f: + model_config = json.load(f) + else: + model_config = None + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) + + model_type = model_config["model_type"] + + if model_type == "diffusion_cond": + ui = create_txt2audio_ui(model_config) + elif model_type == "diffusion_uncond": + ui = create_diffusion_uncond_ui(model_config) + elif model_type == "autoencoder" or model_type == "diffusion_autoencoder": + ui = create_autoencoder_ui(model_config) + elif model_type == "diffusion_prior": + ui = create_diffusion_prior_ui(model_config) + elif model_type == "lm": + ui = create_lm_ui(model_config) + + return ui \ No newline at end of file diff --git a/stable_audio_tools/interface/testing.py b/stable_audio_tools/interface/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fcd0abe4df5a6c923213f8f6464bec7fd63e50 --- /dev/null +++ b/stable_audio_tools/interface/testing.py @@ -0,0 +1,409 @@ +import gc +import numpy as np +import json +import torch +import torchaudio +import os +import re + +from aeiou.viz import audio_spectrogram_image +from einops import rearrange +from safetensors.torch import load_file +from torch.nn import functional as F +from torchaudio import transforms as T + +from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..models.factory import create_model_from_config +from ..models.pretrained import get_pretrained_model +from ..models.utils import load_ckpt_state_dict +from ..inference.utils import prepare_audio +from ..training.utils import copy_state_dict + + +model = None +sample_rate = 44100 +sample_size = 524288 + +def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + global model, sample_rate, sample_size + + if pretrained_name is not None: + print(f"Loading pretrained model 
{pretrained_name}") + model, model_config = get_pretrained_model(pretrained_name) + + elif model_config is not None and model_ckpt_path is not None: + print(f"Creating model from config") + model = create_model_from_config(model_config) + + print(f"Loading model checkpoint from {model_ckpt_path}") + # Load checkpoint + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + + if pretransform_ckpt_path is not None: + print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") + model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) + print(f"Done loading pretransform") + + model.to(device).eval().requires_grad_(False) + + if model_half: + model.to(torch.float16) + + print(f"Done loading model") + + return model, model_config + +def generate_cond_with_path( + prompt, + negative_prompt=None, + seconds_start=0, + seconds_total=30, + latitude = 0.0, + longitude = 0.0, + temperature = 0.0, + humidity = 0.0, + wind_speed = 0.0, + pressure = 0.0, + minutes_of_day = 0.0, + day_of_year = 0.0, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-2m-sde", + sigma_min=0.03, + sigma_max=50, + cfg_rescale=0.4, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1, + destination_folder=None, + file_name=None + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + print(f"Prompt: {prompt}") + + global preview_images + preview_images = [] + if preview_every == 0: + preview_every = None + + # Return fake stereo audio + conditioning = [{"prompt": prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size + + if negative_prompt: + negative_conditioning = [{"prompt": negative_prompt, "latitude": latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total}] * batch_size + else: + negative_conditioning = None + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def 
progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + denoised = rearrange(denoised, "b d n -> d (b n)") + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + # If inpainting, send mask args + # This will definitely change in the future + if mask_cropfrom is not None: + mask_args = { + "cropfrom": mask_cropfrom, + "pastefrom": mask_pastefrom, + "pasteto": mask_pasteto, + "maskstart": mask_maskstart, + "maskend": mask_maskend, + "softnessL": mask_softnessL, + "softnessR": mask_softnessR, + "marination": mask_marination, + } + else: + mask_args = None + + # Do the audio generation + audio = generate_diffusion_cond( + model, + conditioning=conditioning, + negative_conditioning=negative_conditioning, + steps=steps, + cfg_scale=cfg_scale, + batch_size=batch_size, + sample_size=input_sample_size, + sample_rate=sample_rate, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + mask_args = mask_args, + callback = progress_callback if preview_every is not None else None, + scale_phi = cfg_rescale + ) + + # Convert to WAV file + audio = rearrange(audio, "b d n -> d (b n)") + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + #save to the desired folder with the required filename and add the .wav extension + + if destination_folder is not None and file_name is not None: + torchaudio.save(f"{destination_folder}/{file_name}.wav", audio, sample_rate) + + + + # Let's look at a nice spectrogram too + # audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + # return ("output.wav", [audio_spectrogram, *preview_images]) + + + +def generate_lm( + temperature=1.0, + top_p=0.95, + top_k=0, + batch_size=1, + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + audio = model.generate_audio( + batch_size=batch_size, + max_gen_len = sample_size//model.pretransform.downsampling_ratio, + conditioning=None, + temp=temperature, + top_p=top_p, + top_k=top_k, + use_cache=True + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram]) + + + + +def autoencoder_process(audio, latent_noise, n_quantizers): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) + else: + audio = audio.transpose(0, 1) + + audio = model.preprocess_audio_for_encoder(audio, in_sr) + # Note: If you need to do chunked encoding, to reduce VRAM, + # then add these arguments to encode_audio and 
decode_audio: chunked=True, overlap=32, chunk_size=128 + # To turn it off, do chunked=False + # Optimal overlap and chunk_size values will depend on the model. + # See encode_audio & decode_audio in autoencoders.py for more info + # Get dtype of model + dtype = next(model.parameters()).dtype + + audio = audio.to(dtype) + + if n_quantizers > 0: + latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers) + else: + latents = model.encode_audio(audio, chunked=False) + + if latent_noise > 0: + latents = latents + torch.randn_like(latents) * latent_noise + + audio = model.decode_audio(latents, chunked=False) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def load_and_generate(model_path, json_dir, output_dir): + """Load JSON files and generate audio for each set of conditions.""" + # List all files in the json_dir + files = os.listdir(json_dir) + + # Filter for JSON files + json_files = [file for file in files if file.endswith('.json')] + + if not json_files: + print(f"No JSON files found in {json_dir}. Please check the directory path and file permissions.") + return + + for json_filename in json_files: + json_file_path = os.path.join(json_dir, json_filename) + + try: + with open(json_file_path, 'r') as file: + data = json.load(file) + except Exception as e: + print(f"Failed to read or parse {json_file_path}: {e}") + continue + + # Print the JSON path + print(json_file_path) + + # Extract conditions from JSON + conditions = { + 'birdSpecies': data['birdSpecies'], + 'latitude': data['coord']['lat'], + 'longitude': data['coord']['lon'], + 'temperature': data['main']['temp'], + 'humidity': data['main']['humidity'], + 'pressure': data['main']['pressure'], + 'wind_speed': data['wind']['speed'], + 'day_of_year': data['dayOfYear'], + 'minutes_of_day': data['minutesOfDay'] + } + + # Extract base filename components + step_number = re.search(r'step=(\d+)', model_path).group(1) + bird_species = conditions['birdSpecies'].replace(' ', '_') + base_filename = f"{bird_species}_{os.path.splitext(json_filename)[0]}_{step_number}_cfg_scale_" + + + + #An array of cfg scale values to test + cfg_scales = [1.8, 2.5, 4.0, 5.0, 12.0] + + # Generate audio we do this 4 times with a loop + for scale in cfg_scales: + generate_cond_with_path(prompt = "", + negative_prompt="", + seconds_start=0, + seconds_total=22, + latitude = conditions['latitude'], + longitude = conditions['longitude'], + temperature = conditions['temperature'], + humidity = conditions['humidity'], + wind_speed = conditions['wind_speed'], + pressure = conditions['pressure'], + minutes_of_day = conditions['minutes_of_day'], + day_of_year = conditions['day_of_year'], + cfg_scale=scale, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-2m-sde", + sigma_min=0.03, + sigma_max=50, + cfg_rescale=0.4, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1, + destination_folder=output_dir, + file_name=base_filename + str(scale)) + + +def runTests(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False, json_dir=None, output_dir=None): + assert (pretrained_name is not None) ^ (model_config_path is not None and 
ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" + + if model_config_path is not None: + # Load config from json file + with open(model_config_path) as f: + model_config = json.load(f) + else: + model_config = None + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) + + # Ensure output directory exists- os.makedirs(args.output_dir, exist_ok=True) + + # Process all JSON files and generate audio + load_and_generate(ckpt_path, json_dir, output_dir) + + + + + + + diff --git a/stable_audio_tools/models/__init__.py b/stable_audio_tools/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e27bbcb19a00a93e05ed6cf2a3a38895f26975d --- /dev/null +++ b/stable_audio_tools/models/__init__.py @@ -0,0 +1 @@ +from .factory import create_model_from_config, create_model_from_config_path \ No newline at end of file diff --git a/stable_audio_tools/models/adp.py b/stable_audio_tools/models/adp.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8f5dc2d78a0306d93235bea2391c4df9dd873a --- /dev/null +++ b/stable_audio_tools/models/adp.py @@ -0,0 +1,1588 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d + +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from 
https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. 
Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: 
Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + 
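+ # the rearrange in forward() folds the patch back into the time axis, so this block sees in_channels // patch_size channels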
out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) + + self.use_flash = False + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (False, True, True) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... 
n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + 
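+ # GroupNorm -> 1x1 Conv1d -> rearrange to (batch, time, channels) for the attention blocks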
nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and 
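+ # need attention_multiplier plus at least one of attention_heads / attention_features (the missing one is derived below)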
exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + 
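+ # upsample the time axis by `factor` (transposed conv, or nearest-neighbour + conv when use_nearest=True)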
out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: 
bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + 
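+ # one DownsampleBlock1d per level: channels scale multipliers[i] -> multipliers[i + 1], time is downsampled by factors[i]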
) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) 
else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + +""" Conditioning Modules """ + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + +def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, 
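+ # where the mask is True, replace the conditioning with the learned 'unconditional' fixed embedding (classifier-free guidance dropout)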
fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class 
UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... -> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c 
f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/stable_audio_tools/models/autoencoders.py b/stable_audio_tools/models/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..878a254bc7ecabc6c81be8e40d3d332e900f44a1 --- /dev/null +++ b/stable_audio_tools/models/autoencoders.py @@ -0,0 +1,800 @@ +import torch +import math +import numpy as np + +from torch import nn, sin, pow +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d +from typing import List, Literal, Dict, Any, Callable +from einops import rearrange + +from ..inference.sampling import sample +from ..inference.utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform, AutoencoderPretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels) + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + act, + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + act, + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + # Disable checkpoint until tensor mismatch is fixed + #x = checkpoint(self.layers, x) + x = self.layers(x) + + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels) + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + act, + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, 
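+ # nearest-neighbour upsample followed by a stride-1 conv, an alternative to the transposed conv in the else branch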
+ kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels) + + self.layers = nn.Sequential( + act, + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, 
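+ # swap out DAC's first conv so the encoder accepts multi-channel input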
padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + 
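+ # run the pretransform decode one item at a time to keep peak memory low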
decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. + audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. 
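+ Each chunk is encoded independently; the overlapping edge latents are then trimmed and the chunks are stitched back into a single latent sequence.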
+ Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. 
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +class DiffusionAutoencoder(AudioAutoencoder): + def __init__( + self, + diffusion: ConditionedDiffusionModel, + diffusion_downsampling_ratio, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.diffusion = diffusion + + self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio + + if self.encoder is not None: + # Shrink the initial encoder parameters to avoid saturated latents + with torch.no_grad(): + for param in self.encoder.parameters(): + param *= 0.5 + + def decode(self, latents, steps=100): + + upsampled_length = latents.shape[2] * self.downsampling_ratio + + if self.bottleneck is not None: + latents = self.bottleneck.decode(latents) + + if self.decoder is not None: + latents = self.decode(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != upsampled_length: + latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + + noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) + decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + + return decoded + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not 
None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip 
+ ) + +def create_diffAE_from_config(config: Dict[str, Any]): + + diffae_config = config["model"] + + if "encoder" in diffae_config: + encoder = create_encoder_from_config(diffae_config["encoder"]) + else: + encoder = None + + if "decoder" in diffae_config: + decoder = create_decoder_from_config(diffae_config["decoder"]) + else: + decoder = None + + diffusion_model_type = diffae_config["diffusion"]["type"] + + if diffusion_model_type == "DAU1d": + diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "adp_1d": + diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "dit": + diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) + + latent_dim = diffae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = diffae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = diffae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + bottleneck = diffae_config.get("bottleneck", None) + + pretransform = diffae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + diffusion_downsampling_ratio = None, + + if diffusion_model_type == "DAU1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) + elif diffusion_model_type == "adp_1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) + elif diffusion_model_type == "dit": + diffusion_downsampling_ratio = 1 + + return DiffusionAutoencoder( + encoder=encoder, + decoder=decoder, + diffusion=diffusion, + io_channels=io_channels, + sample_rate=sample_rate, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + diffusion_downsampling_ratio=diffusion_downsampling_ratio, + bottleneck=bottleneck, + pretransform=pretransform + ) \ No newline at end of file diff --git a/stable_audio_tools/models/blocks.py b/stable_audio_tools/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc0f6295037d8d8601cb616dc6cc4be030763d1 --- /dev/null +++ b/stable_audio_tools/models/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, 
padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = False + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (False, True, True) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + 
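+        # The interpolation kernel is doubled because conv_transpose1d with stride 2 effectively inserts zeros between input samples; scaling by 2 keeps the upsampled signal at the original amplitude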
self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = 
torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +try: + snake_beta = torch.compile(snake_beta) +except RuntimeError: + pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/stable_audio_tools/models/bottleneck.py b/stable_audio_tools/models/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e8ac2a1f95783073adb226dc0905293887a02a --- /dev/null +++ b/stable_audio_tools/models/bottleneck.py @@ -0,0 +1,326 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + 
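+    # decode() is the identity: tanh is only applied on the encode side to keep latents in (-1, 1), and the decoder network is trained directly on the squashed latents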
+ def decode(self, x): + return x + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + mmd = compute_mmd(x) + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + 
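+        # ResidualVQ returns a loss per quantizer stage; mean() reduces it to a single scalar for the training objective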
info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, dim, levels): + super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices") + self.quantizer = FSQ(levels=[levels] * dim) + + def encode(self, x, return_info=False): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, 
**kwargs) \ No newline at end of file diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..567a4fc6ac0516e10679e8aa499027520d3c6917 --- /dev/null +++ b/stable_audio_tools/models/conditioners.py @@ -0,0 +1,558 @@ +#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py + +import torch +import logging, warnings +import string +import typing as tp +import gc + +from .adp import NumberEmbedder +from ..inference.utils import set_audio_channels +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..training.utils import copy_state_dict +from .utils import load_ckpt_state_dict + +from torch import nn + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False, + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + +class IntConditioner(Conditioner): + def __init__(self, + output_dim: int, + min_val: int=0, + max_val: int=512 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) + + def forward(self, ints: tp.List[int], device=None) -> tp.Any: + + #self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats: tp.List[float], device=None) -> tp.Any: + + # Cast the inputs to floats + floats = [float(x) for x in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] + +class CLAPTextConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + use_text_features = False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False): + super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) + + self.use_text_features = use_text_features + self.feature_layer_ix = feature_layer_ix + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from 
laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.text_branch.requires_grad_(True) + self.model.model.text_branch.train() + else: + self.model.model.text_branch.requires_grad_(False) + self.model.model.text_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.audio_branch + + gc.collect() + torch.cuda.empty_cache() + + def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): + prompt_tokens = self.model.tokenizer(prompts) + attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) + prompt_features = self.model.model.text_branch( + input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), + attention_mask=attention_mask, + output_hidden_states=True + )["hidden_states"][layer_ix] + + return prompt_features, attention_mask + + def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + + if self.use_text_features: + if len(texts) == 1: + text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features = text_features[:1, ...] + text_attention_mask = text_attention_mask[:1, ...] + else: + text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + return [self.proj_out(text_features), text_attention_mask] + + # Fix for CLAP bug when only one text is passed + if len(texts) == 1: + text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] 
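+            # A dummy empty prompt is batched with the real one and only the first embedding is kept, working around CLAP misbehaving when a single text is passed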
+ else: + text_embedding = self.model.get_text_embedding(texts, use_tensor=True) + + text_embedding = text_embedding.unsqueeze(1).to(device) + + return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + +class CLAPAudioConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False): + super().__init__(512, output_dim, project_out=project_out) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.audio_branch.requires_grad_(True) + self.model.model.audio_branch.train() + else: + self.model.model.audio_branch.requires_grad_(False) + self.model.model.audio_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.text_branch + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + + self.model.to(device) + + if isinstance(audios, list) or isinstance(audios, tuple): + audios = torch.cat(audios, dim=0) + + # Convert to mono + mono_audios = audios.mean(dim=1) + + with torch.cuda.amp.autocast(enabled=False): + audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) + + audio_embedding = audio_embedding.unsqueeze(1).to(device) + + return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + +class T5Conditioner(Conditioner): + + T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl"] + + T5_MODEL_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "t5-xl": 2048, + "t5-xxl": 4096, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 128, + enable_grad: bool = False, + project_out: bool = False, + ): + assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" + super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) + # model = T5EncoderModel.from_pretrained(t5_model_name, 
max_length=max_length).train(enable_grad).requires_grad_(enable_grad) + self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) + model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + + embeddings = self.proj_out(embeddings.float()) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PhonemeConditioner(Conditioner): + """ + A conditioner that turns text into phonemes and embeds them using a lookup table + Only works for English text + + Args: + output_dim: the dimension of the output embeddings + max_length: the maximum number of phonemes to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from g2p_en import G2p + + self.max_length = max_length + + self.g2p = G2p() + + # Reserving 0 for padding, 1 for ignored + self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.phoneme_embedder.to(device) + self.proj_out.to(device) + + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + + phoneme_ignore = [" ", *string.punctuation] + + # Remove ignored phonemes and cut to max length + batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] + + # Convert to ids + phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + + #Pad to match longest and make a mask tensor for the padding + longest = max([len(ids) for ids in phoneme_ids]) + phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] + + phoneme_ids = torch.tensor(phoneme_ids).to(device) + + # Convert to embeddings + phoneme_embeds = self.phoneme_embedder(phoneme_ids) + + phoneme_embeds = self.proj_out(phoneme_embeds) + + return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) + +class TokenizerLUTConditioner(Conditioner): + """ + A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary + + Args: + tokenizer_name: the name of the tokenizer from the Hugging Face transformers library + output_dim: the dimension of the output embeddings + max_length: the maximum length of the text to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + tokenizer_name: str, # Name of a tokenizer 
from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from transformers import AutoTokenizer + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + finally: + logging.disable(previous_level) + + self.max_length = max_length + + self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + embeddings = self.token_embedder(input_ids) + + embeddings = self.proj_out(embeddings) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PretransformConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + audio = torch.cat(audio, dim=0) + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + + latents = self.proj_out(latents) + + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + +class MultiConditioner(nn.Module): + """ + A module that applies multiple conditioners to an input dictionary based on the keys + + Args: + conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") + default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. 
{"prompt_t5": "prompt"}) + """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): + super().__init__() + + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys + + def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + output = {} + + for key, conditioner in self.conditioners.items(): + condition_key = key + + conditioner_inputs = [] + + for x in batch_metadata: + + if condition_key not in x: + if condition_key in self.default_keys: + condition_key = self.default_keys[condition_key] + else: + raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") + + #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list + if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + conditioner_inputs.append(x[condition_key][0]) + else: + conditioner_inputs.append(x[condition_key]) + + output[key] = conditioner(conditioner_inputs, device) + + return output + +def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: + """ + Create a MultiConditioner from a conditioning config dictionary + + Args: + config: the conditioning config dictionary + device: the device to put the conditioners on + """ + conditioners = {} + cond_dim = config["cond_dim"] + + default_keys = config.get("default_keys", {}) + + for conditioner_info in config["configs"]: + id = conditioner_info["id"] + + conditioner_type = conditioner_info["type"] + + conditioner_config = {"output_dim": cond_dim} + + conditioner_config.update(conditioner_info["config"]) + + if conditioner_type == "t5": + conditioners[id] = T5Conditioner(**conditioner_config) + elif conditioner_type == "clap_text": + conditioners[id] = CLAPTextConditioner(**conditioner_config) + elif conditioner_type == "clap_audio": + conditioners[id] = CLAPAudioConditioner(**conditioner_config) + elif conditioner_type == "int": + conditioners[id] = IntConditioner(**conditioner_config) + elif conditioner_type == "number": + conditioners[id] = NumberConditioner(**conditioner_config) + elif conditioner_type == "phoneme": + conditioners[id] = PhonemeConditioner(**conditioner_config) + elif conditioner_type == "lut": + conditioners[id] = TokenizerLUTConditioner(**conditioner_config) + elif conditioner_type == "pretransform": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) + else: + raise ValueError(f"Unknown conditioner type: {conditioner_type}") + + return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/stable_audio_tools/models/diffusion.py b/stable_audio_tools/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3c70edbf035b5084dc597344da9b7813c4656df2 --- /dev/null +++ b/stable_audio_tools/models/diffusion.py @@ -0,0 +1,678 @@ +import torch +from torch import nn +from 
torch.nn import functional as F +from functools import partial, reduce +import numpy as np +import typing as tp + +from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .dit import DiffusionTransformer +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..inference.generation import generate_diffusion_cond + +from .adp import UNetCFG1d, UNet1d + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionModel(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, t, **kwargs): + raise NotImplementedError() + +class DiffusionModelWrapper(nn.Module): + def __init__( + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, + ): + super().__init__() + self.io_channels = io_channels + self.sample_size = sample_size + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.model = model + + if pretransform is not None: + self.pretransform = pretransform + else: + self.pretransform = None + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class ConditionedDiffusionModel(nn.Module): + def __init__(self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + self.supports_cross_attention = supports_cross_attention + self.supports_input_concat = supports_input_concat + self.supports_global_cond = supports_global_cond + self.supports_prepend_cond = supports_prepend_cond + + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs): + raise NotImplementedError() + +class ConditionedDiffusionModelWrapper(nn.Module): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + 
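+        # The *_cond_ids lists above select which named conditioner outputs are routed to cross-attention, global, input-concat and prepend conditioning (see get_conditioning_inputs below);
+        # min_input_length is the smallest input length (presumably in audio samples) the wrapper is expected to handle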
self.min_input_length = min_input_length + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = [] + cross_attention_masks = [] + + for key in self.cross_attn_cond_ids: + cross_attn_in, cross_attn_mask = cond[key] + + # Add sequence dimension if it's not there + if len(cross_attn_in.shape) == 2: + cross_attn_in = cross_attn_in.unsqueeze(1) + cross_attn_mask = cross_attn_mask.unsqueeze(1) + + cross_attention_input.append(cross_attn_in) + cross_attention_masks.append(cross_attn_mask) + + cross_attention_input = torch.cat(cross_attention_input, dim=1) + cross_attention_masks = torch.cat(cross_attention_masks, dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if len(self.input_concat_ids) > 0: + # Concatenate all input concat conditioning inputs over the channel dimension + # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) + input_concat_cond = torch.cat([cond[key][0] for key in self.input_concat_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_cross_attn_mask": cross_attention_masks, + "negative_global_cond": global_cond, + "negative_input_concat_cond": input_concat_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "cross_attn_mask": cross_attention_masks, + "global_cond": global_cond, + "input_concat_cond": input_concat_cond, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class UNetCFG1DWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) + + self.model = UNetCFG1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + p = 
Profiler() + + p.tick("start") + + channels_list = None + if input_concat_cond is not None: + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + embedding=cross_attn_cond, + embedding_mask=cross_attn_mask, + features=global_cond, + channels_list=channels_list, + embedding_scale=cfg_scale, + embedding_mask_proba=cfg_dropout_prob, + batch_cfg=batch_cfg, + rescale_cfg=rescale_cfg, + negative_embedding=negative_cross_attn_cond, + negative_embedding_mask=negative_cross_attn_mask, + **kwargs) + + p.tick("UNetCFG1D forward") + + #print(f"Profiler: {p}") + return outputs + +class UNet1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) + + self.model = UNet1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs): + + channels_list = None + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + features=global_cond, + channels_list=channels_list, + **kwargs) + + return outputs + +class UNet1DUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = UNet1d(in_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class DAU1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) + + self.model = DiffusionAttnUnet1D(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs): + + return self.model(x, t, cond = input_concat_cond) + +class DiffusionAttnUnet1D(nn.Module): + def __init__( + self, + io_channels = 2, + depth=14, + n_attn_layers = 6, + channels = [128, 128, 256, 256] + [512] * 10, + cond_dim = 0, + cond_noise_aug = False, + kernel_size = 5, + learned_resample = False, + strides = [2] * 13, + conv_bias = True, + use_snake = False + ): + super().__init__() + + self.cond_noise_aug = cond_noise_aug + + self.io_channels = io_channels + + if self.cond_noise_aug: + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_embed = FourierFeatures(1, 16) + + attn_layer = depth - n_attn_layers + + strides 
= [1] + strides + + block = nn.Identity() + + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + + for i in range(depth, 0, -1): + c = channels[i - 1] + stride = strides[i-1] + if stride > 2 and not learned_resample: + raise ValueError("Must have stride 2 without learned resampling") + + if i > 1: + c_prev = channels[i - 2] + add_attn = i >= attn_layer and n_attn_layers > 0 + block = SkipBlock( + Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + conv_block(c_prev, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + block, + conv_block(c * 2 if i != depth else c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c_prev), + SelfAttention1d(c_prev, c_prev // + 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + ) + else: + cond_embed_dim = 16 if not self.cond_noise_aug else 32 + block = nn.Sequential( + conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), + conv_block(c, c, c), + conv_block(c, c, c), + block, + conv_block(c * 2, c, c), + conv_block(c, c, c), + conv_block(c, c, io_channels, is_last=True), + ) + self.net = block + + with torch.no_grad(): + for param in self.net.parameters(): + param *= 0.5 + + def forward(self, x, t, cond=None, cond_aug_scale=None): + + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) + + inputs = [x, timestep_embed] + + if cond is not None: + if cond.shape[2] != x.shape[2]: + cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + + if self.cond_noise_aug: + # Get a random number between 0 and 1, uniformly sampled + if cond_aug_scale is None: + aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) + else: + aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) + + # Add noise to the conditioning signal + cond = cond + torch.randn_like(cond) * aug_level[:, None, None] + + # Get embedding for noise cond level, reusing timestamp_embed + aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) + + inputs.append(aug_level_embed) + + inputs.append(cond) + + outputs = self.net(torch.cat(inputs, dim=1)) + + return outputs + +class DiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = DiffusionTransformer(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for 
DiTWrapper" + + return self.model( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_mask, + negative_cross_attn_cond=negative_cross_attn_cond, + negative_cross_attn_mask=negative_cross_attn_mask, + input_concat_cond=input_concat_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + global_embed=global_cond, + **kwargs) + +class DiTUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + model_type = diffusion_uncond_config.get('type', None) + + diffusion_config = diffusion_uncond_config.get('config', {}) + + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **diffusion_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + **diffusion_config + ) + + elif model_type == "dit": + model = DiTUncondWrapper( + **diffusion_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): + + model_config = config["model"] + + model_type = config["model_type"] + + diffusion_config = model_config.get('diffusion', None) + assert diffusion_config is not None, "Must specify diffusion config" + + diffusion_model_type = diffusion_config.get('type', None) + assert diffusion_model_type is not None, "Must specify diffusion model type" + + diffusion_model_config = diffusion_config.get('config', None) + assert diffusion_model_config is not None, "Must specify diffusion model config" + + if diffusion_model_type == 'adp_cfg_1d': + diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) + elif diffusion_model_type == 'adp_1d': + diffusion_model = UNet1DCondWrapper(**diffusion_model_config) + elif diffusion_model_type == 'dit': + diffusion_model = DiTWrapper(**diffusion_model_config) + + io_channels = model_config.get('io_channels', None) + assert io_channels is not None, "Must specify io_channels in model config" + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + 
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + + pretransform = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": + min_input_length *= np.prod(diffusion_model_config["factors"]) + elif diffusion_model_type == "dit": + min_input_length *= diffusion_model.model.patch_size + + # Get the proper wrapper class + + if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint": + wrapper_fn = ConditionedDiffusionModelWrapper + elif model_type == "diffusion_prior": + prior_type = model_config.get("prior_type", None) + assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + + if prior_type == "mono_stereo": + from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior + elif prior_type == "source_separation": + from .diffusion_prior import SourceSeparationDiffusionPrior + wrapper_fn = SourceSeparationDiffusionPrior + + return wrapper_fn( + diffusion_model, + conditioner, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + pretransform=pretransform, + io_channels=io_channels + ) \ No newline at end of file diff --git a/stable_audio_tools/models/diffusion_prior.py b/stable_audio_tools/models/diffusion_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..eee63e0ebc778aec883bd3576421758238444463 --- /dev/null +++ b/stable_audio_tools/models/diffusion_prior.py @@ -0,0 +1,151 @@ +from enum import Enum +import typing as tp + +from .diffusion import ConditionedDiffusionModelWrapper +from ..inference.generation import generate_diffusion_cond +from ..inference.utils import prepare_audio + +import torch +from torch.nn import functional as F +from torchaudio import transforms as T + +# Define prior types enum +class PriorType(Enum): + MonoToStereo = 1 + SourceSeparation = 2 + +class DiffusionPrior(ConditionedDiffusionModelWrapper): + def __init__(self, *args, prior_type: PriorType=None, **kwargs): + super().__init__(*args, **kwargs) + self.prior_type = prior_type + +class MonoToStereoDiffusionPrior(DiffusionPrior): + def __init__(self, *args, **kwargs): + super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) + + def stereoize( + self, + audio: torch.Tensor, # (batch, channels, time) + in_sr: int, + steps: int, + sampler_kwargs: dict = {}, + ): + """ + Generate stereo audio from mono audio using a pre-trained diffusion prior + + Args: + audio: The mono audio to convert to stereo + in_sr: The sample rate of the input audio + steps: The number of diffusion steps to run + sampler_kwargs: Keyword arguments to pass to the diffusion sampler + """ + + device = audio.device + + sample_rate = self.sample_rate + + # Resample input audio if necessary + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(audio.device) + audio = resample_tf(audio) + + audio_length = audio.shape[-1] + + # Pad input audio to be compatible with the model 
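+ # Round audio_length up to the nearest multiple of min_input_length; the outer modulo keeps the pad at zero when the length is already aligned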
+ min_length = self.min_input_length + padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length + + # Pad input audio to be compatible with the model + if padded_input_length > audio_length: + audio = F.pad(audio, (0, padded_input_length - audio_length)) + + # Make audio mono, duplicate to stereo + dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1) + + if self.pretransform is not None: + dual_mono = self.pretransform.encode(dual_mono) + + conditioning = {"source": [dual_mono]} + + stereo_audio = generate_diffusion_cond( + self, + conditioning_tensors=conditioning, + steps=steps, + sample_size=padded_input_length, + sample_rate=sample_rate, + device=device, + **sampler_kwargs, + ) + + return stereo_audio + + +class SourceSeparationDiffusionPrior(DiffusionPrior): + """ + A diffusion prior model made for conditioned source separation + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs) + + def separate( + self, + mixed_audio: torch.Tensor, # (batch, channels, time) + in_sr: int, + steps: int, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + sampler_kwargs: dict = {}, + ): + """ + Separate audio sources based on conditioning using a pre-trained diffusion prior + + Args: + mixed_audio: The mixed audio to separate + in_sr: The sample rate of the input audio + steps: The number of diffusion steps to run + conditioning: The conditioning to use for source separation + conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored. + sampler_kwargs: Keyword arguments to pass to the diffusion sampler + """ + + device = mixed_audio.device + + sample_rate = self.sample_rate + + # Resample input audio if necessary + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(mixed_audio.device) + mixed_audio = resample_tf(mixed_audio) + + audio_length = mixed_audio.shape[-1] + + # Pad input audio to be compatible with the model + min_length = self.min_input_length + padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length + + # Pad input audio to be compatible with the model + if padded_input_length > audio_length: + mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length)) + + if self.pretransform is not None: + mixed_audio = self.pretransform.encode(mixed_audio) + + # Conditioning + assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation" + if conditioning_tensors is None: + conditioning_tensors = self.conditioner(conditioning, device) + + # Pass in the mixture audio as conditioning + conditioning_tensors["source"] = [mixed_audio] + + stereo_audio = generate_diffusion_cond( + self, + conditioning_tensors=conditioning_tensors, + steps=steps, + sample_size=padded_input_length, + sample_rate=sample_rate, + device=device, + **sampler_kwargs, + ) + + return stereo_audio \ No newline at end of file diff --git a/stable_audio_tools/models/discriminators.py b/stable_audio_tools/models/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..b593168df965bb1f57881ea79edbc2f66478c6c2 --- /dev/null +++ b/stable_audio_tools/models/discriminators.py @@ -0,0 +1,546 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from functools import reduce +import typing as tp +from einops 
import rearrange +from audiotools import AudioSignal, STFTParams +from dac.model.discriminator import WNConv1d, WNConv2d + +def get_hinge_losses(score_real, score_fake): + gen_loss = -score_fake.mean() + dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() + return dis_loss, gen_loss + +class EncodecDiscriminator(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + from encodec.msstftd import MultiScaleSTFTDiscriminator + + self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) + + def forward(self, x): + logits, features = self.discriminators(x) + return logits, features + + def loss(self, x, y): + feature_matching_distance = 0. + logits_true, feature_true = self.forward(x) + logits_fake, feature_fake = self.forward(y) + + dis_loss = torch.tensor(0.) + adv_loss = torch.tensor(0.) + + for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda x, y: abs(x - y).mean(), + scale_true, + scale_fake, + )) / len(scale_true) + + _dis, _adv = get_hinge_losses( + logits_true[i], + logits_fake[i], + ) + + dis_loss = dis_loss + _dis + adv_loss = adv_loss + _adv + + return dis_loss, adv_loss, feature_matching_distance + +# Discriminators from oobleck + +IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] + +TensorDict = tp.Dict[str, torch.Tensor] + +class SharedDiscriminatorConvNet(nn.Module): + + def __init__( + self, + in_size: int, + convolution: tp.Union[nn.Conv1d, nn.Conv2d], + out_size: int = 1, + capacity: int = 32, + n_layers: int = 4, + kernel_size: int = 15, + stride: int = 4, + activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), + normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, + ) -> None: + super().__init__() + channels = [in_size] + channels += list(capacity * 2**np.arange(n_layers)) + + if isinstance(stride, int): + stride = n_layers * [stride] + + net = [] + for i in range(n_layers): + if isinstance(kernel_size, int): + pad = kernel_size // 2 + s = stride[i] + else: + pad = kernel_size[0] // 2 + s = (stride[i], 1) + + net.append( + normalization( + convolution( + channels[i], + channels[i + 1], + kernel_size, + stride=s, + padding=pad, + ))) + net.append(activation()) + + net.append(convolution(channels[-1], out_size, 1)) + + self.net = nn.ModuleList(net) + + def forward(self, x) -> IndividualDiscriminatorOut: + features = [] + for layer in self.net: + x = layer(x) + if isinstance(layer, nn.modules.conv._ConvNd): + features.append(x) + score = x.reshape(x.shape[0], -1).mean(-1) + return score, features + + +class MultiScaleDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + n_scales: int, + **conv_kwargs) -> None: + super().__init__() + layers = [] + for _ in range(n_scales): + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer in self.layers: + s, f = layer(x) + score = score + s + features.extend(f) + x = nn.functional.avg_pool1d(x, 2) + return score, features + +class MultiPeriodDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + periods: tp.Sequence[int], + **conv_kwargs) -> None: + super().__init__() + layers = [] + self.periods = periods + + for _ in periods: + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, 
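+ # The 2D convolutions operate on the (frames, period) view produced by MultiPeriodDiscriminator.fold below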
**conv_kwargs)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer, n in zip(self.layers, self.periods): + s, f = layer(self.fold(x, n)) + score = score + s + features.extend(f) + return score, features + + def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: + pad = (n - (x.shape[-1] % n)) % n + x = nn.functional.pad(x, (0, pad)) + return x.reshape(*x.shape[:2], -1, n) + + +class MultiDiscriminator(nn.Module): + """ + Individual discriminators should take a single tensor as input (NxB C T) and + return a tuple composed of a score tensor (NxB) and a Sequence of Features + Sequence[NxB C' T']. + """ + + def __init__(self, discriminator_list: tp.Sequence[nn.Module], + keys: tp.Sequence[str]) -> None: + super().__init__() + self.discriminators = nn.ModuleList(discriminator_list) + self.keys = keys + + def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: + features = features.chunk(len(self.keys), 0) + return {k: features[i] for i, k in enumerate(self.keys)} + + @staticmethod + def concat_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = [] + if k in dict_a: + if isinstance(dict_a[k], list): + out_dict[k].extend(dict_a[k]) + else: + out_dict[k].append(dict_a[k]) + if k in dict_b: + if isinstance(dict_b[k], list): + out_dict[k].extend(dict_b[k]) + else: + out_dict[k].append(dict_b[k]) + return out_dict + + @staticmethod + def sum_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = 0. + if k in dict_a: + out_dict[k] = out_dict[k] + dict_a[k] + if k in dict_b: + out_dict[k] = out_dict[k] + dict_b[k] + return out_dict + + def forward(self, inputs: TensorDict) -> TensorDict: + discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) + all_scores = [] + all_features = [] + + for discriminator in self.discriminators: + score, features = discriminator(discriminator_input) + scores = self.unpack_tensor_to_dict(score) + scores = {f"score_{k}": scores[k] for k in scores.keys()} + all_scores.append(scores) + + features = map(self.unpack_tensor_to_dict, features) + features = reduce(self.concat_dicts, features) + features = {f"features_{k}": features[k] for k in features.keys()} + all_features.append(features) + + all_scores = reduce(self.sum_dicts, all_scores) + all_features = reduce(self.concat_dicts, all_features) + + inputs.update(all_scores) + inputs.update(all_features) + + return inputs + +class OobleckDiscriminator(nn.Module): + + def __init__( + self, + in_channels=1, + ): + super().__init__() + + multi_scale_discriminator = MultiScaleDiscriminator( + in_channels=in_channels, + n_scales=3, + ) + + multi_period_discriminator = MultiPeriodDiscriminator( + in_channels=in_channels, + periods=[2, 3, 5, 7, 11] + ) + + # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( + # filters=32, + # in_channels = in_channels, + # out_channels = 1, + # n_ffts = [2048, 1024, 512, 256, 128], + # hop_lengths = [512, 256, 128, 64, 32], + # win_lengths = [2048, 1024, 512, 256, 128] + # ) + + self.multi_discriminator = MultiDiscriminator( + [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], + ["reals", "fakes"] + ) + + def loss(self, reals, fakes): + inputs = { + "reals": reals, + "fakes": fakes, + } + + inputs = self.multi_discriminator(inputs) + + scores_real = inputs["score_reals"] + 
scores_fake = inputs["score_fakes"] + + features_real = inputs["features_reals"] + features_fake = inputs["features_fakes"] + + dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) + + feature_matching_distance = torch.tensor(0.) + + for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda real, fake: abs(real - fake).mean(), + scale_real, + scale_fake, + )) / len(scale_real) + + return dis_loss, gen_loss, feature_matching_distance + + +## Discriminators from Descript Audio Codec repo +## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt +class MPD(nn.Module): + def __init__(self, period, channels=1): + super().__init__() + + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): + super().__init__() + + self.convs = nn.ModuleList( + [ + WNConv1d(channels, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + channels: int = 1 + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. 
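+ channels : int, optional + Number of audio channels in the input, by default 1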
+ """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + self.channels = channels + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class DACDiscriminator(nn.Module): + def __init__( + self, + channels: int = 1, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p, channels=channels) for p in periods] + discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
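+ + Example (illustrative sketch; ``fake`` and ``real`` are waveform tensors of shape (batch, channels, time)): + + >>> gan_loss = DACGANLoss(channels=1, sample_rate=44100) + >>> dis_loss, gen_loss, feature_distance = gan_loss.loss(fake, real)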
+ """ + + def __init__(self, **discriminator_kwargs): + super().__init__() + self.discriminator = DACDiscriminator(**discriminator_kwargs) + + def forward(self, fake, real): + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + + def loss(self, fake, real): + gen_loss, feature_distance = self.generator_loss(fake, real) + dis_loss = self.discriminator_loss(fake, real) + + return dis_loss, gen_loss, feature_distance \ No newline at end of file diff --git a/stable_audio_tools/models/dit.py b/stable_audio_tools/models/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..d965fe087ce90613c8763b4ebadbd8cfaa624069 --- /dev/null +++ b/stable_audio_tools/models/dit.py @@ -0,0 +1,358 @@ +import typing as tp + +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from x_transformers import ContinuousTransformerWrapper, Encoder + +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + + if cond_token_dim > 0: + # Conditioning tokens + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "x-transformers": + self.transformer = ContinuousTransformerWrapper( + dim_in=dim_in * patch_size, 
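+ # dim_in folds patch_size frames of (io_channels + input_concat_dim) channels into each transformer token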
+ dim_out=io_channels * patch_size, + max_seq_len=0, #Not relevant without absolute positional embeds + attn_layers = Encoder( + dim=embed_dim, + depth=depth, + heads=num_heads, + attn_flash = True, + cross_attend = cond_token_dim > 0, + dim_context=None if cond_embed_dim == 0 else cond_embed_dim, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + elif self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + + # Timestep embedding is considered a global embedding. 
Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers" or self.transformer_type == "continuous_transformer": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # CFG dropout + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not 
None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + **kwargs) + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + if scale_phi != 0.0: + + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + + return scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + + else: + + return cfg_output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + **kwargs + ) \ No newline at end of file diff --git a/stable_audio_tools/models/factory.py b/stable_audio_tools/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e7474abdb65cd2b83eb2b81cf9536a06cc9794 --- /dev/null +++ b/stable_audio_tools/models/factory.py @@ -0,0 +1,149 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + from .diffusion import 
create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'musicgen': + from .musicgen import create_musicgen_from_config + return create_musicgen_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + return TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + return VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + 
quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + return RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + return DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + return RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + return DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + return L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + return WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + return FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') diff --git a/stable_audio_tools/models/lm.py b/stable_audio_tools/models/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb444c432ba73c0a18551737c1e9dc7d400881e --- /dev/null +++ b/stable_audio_tools/models/lm.py @@ -0,0 +1,531 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .autoencoders import AudioAutoencoder +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone +from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p + +from audiocraft.modules.codebooks_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider, + VALLEPattern, +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. 
when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLanguageModel(nn.Module): + def __init__( + self, + pattern_provider: CodebooksPatternProvider, + backbone: AudioLMBackbone, + num_quantizers: int, + codebook_size: int + ): + super().__init__() + + self.pattern_provider = pattern_provider + self.backbone = backbone + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + + self.masked_token_id = codebook_size + + # Per-quantizer embedders + # Add one for the mask embed + self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) + + # Per-quantizer output heads + self.quantizer_heads = nn.ModuleList([ + nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) + ]) + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + batch, num_quantizers, seq_len = sequence.shape + + assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" + + backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] + + output = self.backbone( + backbone_input, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + # Run output through quantizer heads + logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] + + return logits + + def compute_logits( + self, + codes, #[batch, num_quantizers, seq_len] + **kwargs): + """ + Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning + Handles translation between input sequence and pattern-shifted sequence + Only used during training + """ + + batch, _, seq_len = codes.shape + + pattern = self.pattern_provider.get_pattern(seq_len) + + # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps + shifted_codes, _, _ = pattern.build_pattern_sequence( + codes, + self.masked_token_id, + keep_only_valid_steps=True + ) + + # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] + logits = self(shifted_codes, **kwargs) + + # Rearrange logits to prepare to revert pattern + logits = rearrange(logits, "b n s c -> b c n s") + + # Revert sequence logits back to original sequence length, removing masked steps + logits, _, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=True + ) + + logits = rearrange(logits, "b c n t -> b n t c") + + logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] + + return LMOutput(logits=logits, mask=logits_mask) + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] 
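+ # The ID lists above select which conditioner outputs feed cross-attention, prepend, and global conditioning (see get_conditioning_inputs)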
+ ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + codes, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), 
cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: 
tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - 
start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + codebook_pattern = lm_config.get("codebook_pattern", "delay") + + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'valle': VALLEPattern, + 'musiclm': MusicLMPattern, + } + + pretransform_config = model_config.get("pretransform", None) + + pretransform = create_pretransform_from_config(pretransform_config, sample_rate) + + assert pretransform.is_discrete, "Pretransform must be discrete" + + min_input_length = pretransform.downsampling_ratio + + pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/stable_audio_tools/models/lm_backbone.py b/stable_audio_tools/models/lm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6029c539de72cad7c9f7976871c09ff7c1f439e2 --- /dev/null +++ b/stable_audio_tools/models/lm_backbone.py @@ -0,0 +1,157 @@ +import torch +from torch import nn +from x_transformers import ContinuousTransformerWrapper, Decoder + +from 
.transformer import ContinuousTransformer + +# Interface for backbone of a language model +# Handles conditioning and cross-attention +# Does not have to deal with patterns or quantizer heads +class AudioLMBackbone(nn.Module): + def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): + super().__init__() + + self.embed_dim = embed_dim + self.use_generation_cache = use_generation_cache + + def forward( + self, + x, + cross_attn_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + global_cond=None, + use_cache=False, + **kwargs + ): + raise NotImplementedError + + def reset_generation_cache( + self, + max_seq_len, + batch_size, + dtype=None + ): + pass + + def update_generation_cache( + self, + seqlen_offset + ): + pass + +class XTransformersAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformerWrapper( + dim_in=embed_dim, + dim_out=embed_dim, + max_seq_len=0, #Not relevant without absolute positional embeds, + attn_layers=Decoder( + dim=embed_dim, + attn_flash = True, + cross_attend = cross_attn_cond_dim > 0, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0: + # Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] + +class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + project_cross_attn_cond: bool = False, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformer( + dim=embed_dim, + dim_in=embed_dim, + dim_out=embed_dim, + cross_attend = cross_attn_cond_dim > 0, + cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, + causal=True, + **kwargs + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0 and project_cross_attn_cond: + # 
Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + else: + self.to_cross_attn_embed = nn.Identity() + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] diff --git a/stable_audio_tools/models/local_attention.py b/stable_audio_tools/models/local_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..893ce11fce1f263dd02ff2a2ebe8b5e67426f83f --- /dev/null +++ b/stable_audio_tools/models/local_attention.py @@ -0,0 +1,278 @@ +import torch + +from einops import rearrange +from torch import nn + +from .blocks import AdaRMSNorm +from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py +class ContinuousLocalTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_in = None, + dim_out = None, + causal = False, + local_attn_window_size = 64, + heads = 8, + ff_mult = 2, + cond_dim = 0, + cross_attn_cond_dim = 0, + **kwargs + ): + super().__init__() + + dim_head = dim//heads + + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() + + self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() + + self.local_attn_window_size = local_attn_window_size + + self.cond_dim = cond_dim + + self.cross_attn_cond_dim = cross_attn_cond_dim + + self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) + + for _ in range(depth): + + self.layers.append(nn.ModuleList([ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + Attention( + dim=dim, + dim_heads=dim_head, + dim_context = cross_attn_cond_dim, + zero_init_output=True + ) if self.cross_attn_cond_dim > 0 else nn.Identity(), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim = dim, mult = ff_mult, no_bias=True) + ])) + + def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): + + x = checkpoint(self.project_in, x) + + if prepend_cond is not None: + x = torch.cat([prepend_cond, x], dim=1) + + pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + + for attn_norm, attn, xattn, ff_norm, ff in self.layers: + + residual = x + if cond is not None: + x = checkpoint(attn_norm, x, cond) + else: + x = checkpoint(attn_norm, x) + + x = 
checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual + + if cross_attn_cond is not None: + x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x + + residual = x + + if cond is not None: + x = checkpoint(ff_norm, x, cond) + else: + x = checkpoint(ff_norm, x) + + x = checkpoint(ff, x) + residual + + return checkpoint(self.project_out, x) + +class TransformerDownsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim = 768, + depth = 3, + heads = 12, + downsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.downsample_ratio = downsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size=local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) + + + def forward(self, x): + + x = checkpoint(self.project_in, x) + + # Compute + x = self.transformer(x) + + # Trade sequence length for channels + x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) + + # Project back to embed dim + x = checkpoint(self.project_down, x) + + return x + +class TransformerUpsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim, + depth = 3, + heads = 12, + upsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.upsample_ratio = upsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size = local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) + + def forward(self, x): + + # Project to embed dim + x = checkpoint(self.project_in, x) + + # Project to increase channel dim + x = checkpoint(self.project_up, x) + + # Trade channels for sequence length + x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) + + # Compute + x = self.transformer(x) + + return x + + +class TransformerEncoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [96, 192, 384, 768], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerDownsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + downsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + + return x + + +class TransformerDecoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [768, 384, 192, 96], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 
2], + local_attn_window_size = 64, + **kwargs + ): + + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerUpsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + upsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + return x \ No newline at end of file diff --git a/stable_audio_tools/models/musicgen.py b/stable_audio_tools/models/musicgen.py new file mode 100644 index 0000000000000000000000000000000000000000..0454fe2de4e09e670b636294cb3502ec6400a678 --- /dev/null +++ b/stable_audio_tools/models/musicgen.py @@ -0,0 +1,161 @@ +import torch +import typing as tp +from audiocraft.models import MusicGen, CompressionModel, LMModel +import audiocraft.quantization as qt +from .autoencoders import AudioAutoencoder +from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck + +from audiocraft.modules.codebooks_patterns import ( + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider, + VALLEPattern, +) + +from audiocraft.modules.conditioners import ( + ConditionFuser, + ConditioningProvider, + T5Conditioner, +) + +def create_musicgen_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + if model_config.get("pretrained", False): + model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu") + + if model_config.get("reinit_lm", False): + model.lm._init_weights("gaussian", "current", True) + + return model + + # Create MusicGen model from scratch + compression_config = model_config.get('compression', None) + assert compression_config is not None, 'compression config must be specified in model config' + + compression_type = compression_config.get('type', None) + assert compression_type is not None, 'type must be specified in compression config' + + if compression_type == 'pretrained': + compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"]) + elif compression_type == "dac_rvq_ae": + from .autoencoders import create_autoencoder_from_config + autoencoder = create_autoencoder_from_config({"model": compression_config["config"], "sample_rate": config["sample_rate"]}) + autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"]) + compression_model = DACRVQCompressionModel(autoencoder) + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + codebook_pattern = lm_config.pop("codebook_pattern", "delay") + + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'valle': VALLEPattern, + 'musiclm': MusicLMPattern, + } + + pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks) + + conditioning_config = model_config.get("conditioning", {}) + + condition_output_dim = 
conditioning_config.get("output_dim", 768) + + condition_provider = ConditioningProvider( + conditioners = { + "description": T5Conditioner( + name="t5-base", + output_dim=condition_output_dim, + word_dropout=0.3, + normalize_text=False, + finetune=False, + device="cpu" + ) + } + ) + + condition_fuser = ConditionFuser(fuse2cond={ + "cross": ["description"], + "prepend": [], + "sum": [] + }) + + lm = LMModel( + pattern_provider = pattern_provider, + condition_provider = condition_provider, + fuser = condition_fuser, + n_q = compression_model.num_codebooks, + card = compression_model.cardinality, + **lm_config + ) + + + model = MusicGen( + name = model_config.get("name", "musicgen-scratch"), + compression_model = compression_model, + lm = lm, + max_duration=30 + ) + + return model + +class DACRVQCompressionModel(CompressionModel): + def __init__(self, autoencoder: AudioAutoencoder): + super().__init__() + self.model = autoencoder.eval() + + assert isinstance(self.model.bottleneck, DACRVQBottleneck) or isinstance(self.model.bottleneck, DACRVQVAEBottleneck), "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck" + + self.n_quantizers = self.model.bottleneck.num_quantizers + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + raise NotImplementedError("Forward and training with DAC RVQ not supported") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers) + codes = info["codes"] + return codes, None + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + assert scale is None + z_q = self.decode_latent(codes) + return self.model.decode(z_q) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.bottleneck.quantizer.from_codes(codes)[0] + + @property + def channels(self) -> int: + return self.model.io_channels + + @property + def frame_rate(self) -> float: + return self.model.sample_rate / self.model.downsampling_ratio + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def cardinality(self) -> int: + return self.model.bottleneck.quantizer.codebook_size + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + self.model.bottleneck.num_quantizers + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n \ No newline at end of file diff --git a/stable_audio_tools/models/pqmf.py b/stable_audio_tools/models/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..007fdb51ec797554c1cdd4d9363894d743d970bf --- /dev/null +++ b/stable_audio_tools/models/pqmf.py @@ -0,0 +1,393 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from scipy.optimize import fmin +from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + +class PQMF(nn.Module): + """ + Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. + Uses polyphase representation which is computationally more efficient for real-time. + + Parameters: + - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. + - num_bands (int): Number of desired frequency bands. It must be a power of 2. 
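+
+ Example (illustrative sketch; `signal` is assumed to be a [Batch x Channels x Length] tensor):
+ pqmf = PQMF(attenuation=100, num_bands=16)
+ bands = pqmf(signal) # [Batch x Channels x 16 x ~Length/16]
+ recon = pqmf.inverse(bands) # approximately reconstructs `signal`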
+ """ + + def __init__(self, attenuation, num_bands): + super(PQMF, self).__init__() + + # Ensure num_bands is a power of 2 + is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + assert is_power_of_2, "'num_bands' must be a power of 2." + + # Create the prototype filter + prototype_filter = design_prototype_filter(attenuation, num_bands) + filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) + padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) + + # Register filters and settings + self.register_buffer("filter_bank", padded_filter_bank) + self.register_buffer("prototype", prototype_filter) + self.num_bands = num_bands + + def forward(self, signal): + """Decompose the signal into multiple frequency bands.""" + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + signal = prepare_signal_dimensions(signal) + # The signal length must be a multiple of num_bands. Pad it with zeros. + signal = pad_signal(signal, self.num_bands) + # run it + signal = polyphase_analysis(signal, self.filter_bank) + return apply_alias_cancellation(signal) + + def inverse(self, bands): + """Reconstruct the original signal from the frequency bands.""" + bands = apply_alias_cancellation(bands) + return polyphase_synthesis(bands, self.filter_bank) + + +def prepare_signal_dimensions(signal): + """ + Rearrange signal into Batch x Channels x Length. + + Parameters + ---------- + signal : torch.Tensor or numpy.ndarray + The input signal. + + Returns + ------- + torch.Tensor + Preprocessed signal tensor. + """ + # Convert numpy to torch tensor + if isinstance(signal, np.ndarray): + signal = torch.from_numpy(signal) + + # Ensure tensor + if not isinstance(signal, torch.Tensor): + raise ValueError("Input should be either a numpy array or a PyTorch tensor.") + + # Modify dimension of signal to Batch x Channels x Length + if signal.dim() == 1: + # This is just a mono signal. Unsqueeze to 1 x 1 x Length + signal = signal.unsqueeze(0).unsqueeze(0) + elif signal.dim() == 2: + # This is a multi-channel signal (e.g. stereo) + # Rearrange so that larger dimension (Length) is last + if signal.shape[0] > signal.shape[1]: + signal = signal.T + # Unsqueeze to 1 x Channels x Length + signal = signal.unsqueeze(0) + return signal + +def pad_signal(signal, num_bands): + """ + Pads the signal to make its length divisible by the given number of bands. + + Parameters + ---------- + signal : torch.Tensor + The input signal tensor, where the last dimension represents the signal length. + + num_bands : int + The number of bands by which the signal length should be divisible. + + Returns + ------- + torch.Tensor + The padded signal tensor. If the original signal length was already divisible + by num_bands, returns the original signal unchanged. + """ + remainder = signal.shape[-1] % num_bands + if remainder > 0: + padding_size = num_bands - remainder + signal = nn.functional.pad(signal, (0, padding_size)) + return signal + +def generate_modulated_filter_bank(prototype_filter, num_bands): + """ + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + + Parameters + ---------- + prototype_filter : torch.Tensor + The prototype filter used as the basis for modulation. + num_bands : int + The number of desired subbands or filters. + + Returns + ------- + torch.Tensor + A bank of cosine modulated filters. + """ + + # Initialize indices for modulation. 
+ subband_indices = torch.arange(num_bands).reshape(-1, 1) + + # Calculate the length of the prototype filter. + filter_length = prototype_filter.shape[-1] + + # Generate symmetric time indices centered around zero. + time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) + + # Calculate phase offsets to ensure orthogonality between subbands. + phase_offsets = (-1)**subband_indices * np.pi / 4 + + # Compute the cosine modulation function. + modulation = torch.cos( + (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets + ) + + # Apply modulation to the prototype filter. + modulated_filters = 2 * prototype_filter * modulation + + return modulated_filters + + +def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): + """ + Design a lowpass filter using the Kaiser window. + + Parameters + ---------- + angular_cutoff : float + The angular frequency cutoff of the filter. + attenuation : float + The desired stopband attenuation in decibels (dB). + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The designed lowpass filter coefficients. + """ + + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) + + # Ensure the estimated length is odd. + estimated_length = 2 * (estimated_length // 2) + 1 + + if filter_length is None: + filter_length = estimated_length + + return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + +def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): + """ + Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 + + Parameters + ---------- + angular_cutoff : float + Angular frequency cutoff of the filter. + attenuation : float + Desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. + + Returns + ------- + float + The computed objective (loss) value for the given filter specs. + """ + + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) + convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + +def design_prototype_filter(attenuation, num_bands, filter_length=None): + """ + Design the optimal prototype filter for a multiband system given the desired specs. + + Parameters + ---------- + attenuation : float + The desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The optimal prototype filter coefficients. + """ + + optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, disp=0)[0] + + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) + return torch.tensor(prototype_filter, dtype=torch.float32) + +def pad_to_nearest_power_of_two(x): + """ + Pads the input tensor 'x' on both sides such that its last dimension + becomes the nearest larger power of two. + + Parameters: + ----------- + x : torch.Tensor + The input tensor to be padded. 
+ + Returns: + -------- + torch.Tensor + The padded tensor. + """ + current_length = x.shape[-1] + target_length = 2**math.ceil(math.log2(current_length)) + + total_padding = target_length - current_length + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + + return nn.functional.pad(x, (left_padding, right_padding)) + +def apply_alias_cancellation(x): + """ + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second + row's first element in a tensor. + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during + the reconstruction. + + Parameters: + ----------- + x : torch.Tensor + The input tensor. + + Returns: + -------- + torch.Tensor + Tensor with specific elements' sign inverted for alias cancellation. + """ + + # Create a mask of the same shape as 'x', initialized with all ones + mask = torch.ones_like(x) + + # Update specific elements in the mask to -1 to perform inversion + mask[..., 1::2, ::2] = -1 + + # Apply the mask to the input tensor 'x' + return x * mask + +def ensure_odd_length(tensor): + """ + Pads the last dimension of a tensor to ensure its size is odd. + + Parameters: + ----------- + tensor : torch.Tensor + Input tensor whose last dimension might need padding. + + Returns: + -------- + torch.Tensor + The original tensor if its last dimension was already odd, + or the padded tensor with an odd-sized last dimension. + """ + + last_dim_size = tensor.shape[-1] + + if last_dim_size % 2 == 0: + tensor = nn.functional.pad(tensor, (0, 1)) + + return tensor + +def polyphase_analysis(signal, filter_bank): + """ + Applies the polyphase method to efficiently analyze the signal using a filter bank. + + Parameters: + ----------- + signal : torch.Tensor + Input signal tensor with shape (Batch x Channels x Length). + + filter_bank : torch.Tensor + Filter bank tensor with shape (Bands x Length). + + Returns: + -------- + torch.Tensor + Signal split into sub-bands. (Batch x Channels x Bands x Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) + + # Rearrange the filter bank for matching signal shape + filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) + + # Apply convolution with appropriate padding to maintain spatial dimensions + padding = filter_bank.shape[-1] // 2 + filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) + + # Truncate the last dimension post-convolution to adjust the output shape + filtered_signal = filtered_signal[..., :-1] + # Rearrange the first dimension back into Batch x Channels + filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) + + return filtered_signal + +def polyphase_synthesis(signal, filter_bank): + """ + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + + Parameters + ---------- + signal : torch.Tensor + Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). + + filter_bank : torch.Tensor + Analysis filter bank (shape: Bands x Length). + + should_rearrange : bool, optional + Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. 
+ + Returns + ------- + torch.Tensor + Reconstructed signal (shape: Batch x Channels X Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange the filter bank + filter_bank = filter_bank.flip(-1) + filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) + + # Combine Batch x Channels into one dimension for now. + signal = rearrange(signal, "b c n t -> (b c) n t") + + # Apply convolution with appropriate padding + padding_amount = filter_bank.shape[-1] // 2 + 1 + reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) + + # Scale the result + reconstructed_signal = reconstructed_signal[..., :-1] * num_bands + + # Reorganize the output and truncate + reconstructed_signal = reconstructed_signal.flip(1) + reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] + + return reconstructed_signal \ No newline at end of file diff --git a/stable_audio_tools/models/pretrained.py b/stable_audio_tools/models/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e83af343587da91af92218f309c969c5a975b5ed --- /dev/null +++ b/stable_audio_tools/models/pretrained.py @@ -0,0 +1,25 @@ +import json + +from .factory import create_model_from_config +from .utils import load_ckpt_state_dict + +from huggingface_hub import hf_hub_download + +def get_pretrained_model(name: str): + + model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') + + with open(model_config_path) as f: + model_config = json.load(f) + + model = create_model_from_config(model_config) + + # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file + try: + model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') + except Exception as e: + model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') + + model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + return model, model_config \ No newline at end of file diff --git a/stable_audio_tools/models/pretransforms.py b/stable_audio_tools/models/pretransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..cc471b104c36184324518798d1c5f04f3a6e3811 --- /dev/null +++ b/stable_audio_tools/models/pretransforms.py @@ -0,0 +1,256 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + 
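+ # Runtime options for the frozen autoencoder: model_half casts it to float16 around encode/decode,
+ # and iterate_batch/chunked are passed through to the model's encode_audio/decode_audio calls.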
self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + #model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + model_path = "/g/data/qq08/synth_orn/dac_pretrained/weights_44khz_16kbps.pth" + + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 
512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + from audiocraft.models import CompressionModel + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/stable_audio_tools/models/transformer.py b/stable_audio_tools/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8dad7f664a826d6608e3bb50e3f07566c3d305b1 --- /dev/null +++ b/stable_audio_tools/models/transformer.py @@ -0,0 +1,775 @@ +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError: + flash_attn_kvpacked_func = None + flash_attn_func = None + +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + 
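+ # Thin wrapper around torch.utils.checkpoint.checkpoint that defaults to the non-reentrant
+ # implementation unless the caller explicitly sets use_reentrant.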
kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. 
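+ # use_xpos branch: derive a position-dependent scale for length extrapolation
+ # (note: this path references `seq_len`, which is not defined in this method; it would need to be taken from t, e.g. t.shape[-1]).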
+ + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else 
nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm = False, + natten_kernel_size = None + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + # Using 1d neighborhood attention + self.natten_kernel_size = natten_kernel_size + if natten_kernel_size is not None: + return + + self.use_pt_flash = False + + self.use_fa_flash = False + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + def flash_attn( + self, + q, + k, + v, + mask = None, + causal = None + ): + batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + + causal = self.causal if causal is None else causal + + if q_len == 1 and causal: + causal = False + + if mask is not None: + assert mask.ndim == 4 + mask = mask.expand(batch, heads, q_len, k_len) + + # handle kv cache - this should be bypassable in updated flash attention 2 + + if k_len > q_len and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + if mask is None: + mask = ~causal_mask + else: + mask = mask & ~causal_mask + causal = False + + # manually handle causal mask, if another mask was given + + row_is_entirely_masked = None + + if mask is not None and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + mask = mask & ~causal_mask + + # protect against an entire row being masked out + + row_is_entirely_masked = ~mask.any(dim = -1) + mask[..., 0] = mask[..., 0] | row_is_entirely_masked + + causal = False + + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + is_causal = causal + ) + + # for a row that is entirely masked out, should zero out the output of that row token + + if row_is_entirely_masked is not None: + out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
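+ # out has shape [batch, heads, q_len, dim_head]; the caller merges heads back into the channel dimension.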
+ + return out + + def forward( + self, + x, + context = None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + causal = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm: + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.natten_kernel_size is not None: + if natten is None: + raise ImportError('natten not installed, please install natten to use neighborhood attention') + + dtype_in = q.dtype + q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) + + attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1) + + if final_attn_mask is not None: + attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) + + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + + out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) + + # Prioritize Flash Attention 2 + elif self.use_fa_flash: + assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + # Flash Attention 2 requires FP16 inputs + dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').half(), (q, k, v)) + out = flash_attn_func(q, k, v, causal = causal).to(dtype_in) + + out = rearrange(out, 'b n h d -> b h n d') + + # Fall back to PyTorch implementation + elif self.use_pt_flash: + out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask) + + else: + # Fall back to custom implementation + + scale = 1. 
/ (q.shape[-1] ** 0.5) + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + + i, j, dtype = *dots.shape[-2:], dots.dtype + + mask_value = -torch.finfo(dots.dtype).max + + if final_attn_mask is not None: + dots = dots.masked_fill(~final_attn_mask, mask_value) + + if causal: + causal_mask = self.create_causal_mask(i, j, device = device) + dots = dots.masked_fill(causal_mask, mask_value) + + attn = F.softmax(dots, dim=-1, dtype=torch.float32) + attn = attn.type(dtype) + + out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads = dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads = dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + nn.init.zeros_(self.to_scale_shift_gate[1].weight) + #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + + def forward( + self, + x, + context = None, + 
global_cond=None, + mask = None, + context_mask = None, + rotary_pos_emb = None + ): + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + x = x + residual + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + global_cond = None, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + + mask = torch.cat((prepend_mask, mask), dim = -1) + + # Attention layers + + if self.rotary_pos_emb is not None: + 
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + for layer in self.layers: + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + + x = self.project_out(x) + + return x \ No newline at end of file diff --git a/stable_audio_tools/models/utils.py b/stable_audio_tools/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..caacb8cdbcaf04b274481d11225d99887a231074 --- /dev/null +++ b/stable_audio_tools/models/utils.py @@ -0,0 +1,83 @@ +import torch +from safetensors.torch import load_file + +from torch.nn.utils import remove_weight_norm + +def load_ckpt_state_dict(ckpt_path): + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + return state_dict + +def remove_weight_norm_from_model(model): + for module in model.modules(): + if hasattr(module, "weight"): + print(f"Removing weight norm from {module}") + remove_weight_norm(module) + + return model + +# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + + if num_samples == 1: + q = torch.empty_like(input).exponential_(1, generator=generator) + return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) + + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. 
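+     Example (illustrative sketch, assuming a batch of 2 distributions over 32 candidates):
+         probs = torch.softmax(torch.randn(2, 32), dim=-1)
+         next_token = sample_top_p(probs, p=0.9)  # expected shape: (2, 1), indices into the last dimension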
+ """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/stable_audio_tools/models/wavelets.py b/stable_audio_tools/models/wavelets.py new file mode 100644 index 0000000000000000000000000000000000000000..a359e39110c168aab960d3f79262b464a660e55e --- /dev/null +++ b/stable_audio_tools/models/wavelets.py @@ -0,0 +1,82 @@ +"""The 1D discrete wavelet transform for PyTorch.""" + +from einops import rearrange +import pywt +import torch +from torch import nn +from torch.nn import functional as F +from typing import Literal + + +def get_filter_bank(wavelet): + filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) + if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): + filt = filt[:, 1:] + return filt + +class WaveletEncode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[:2, None] + kernel = torch.flip(kernel, dims=(-1,)) + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels], x[:, self.channels :] + pad = self.kernel.shape[-1] // 2 + low = F.pad(low, (pad, pad), "reflect") + low = F.conv1d(low, self.kernel, stride=2) + rest = rearrange( + rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x + + +class WaveletDecode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[2:, None] + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] + pad = self.kernel.shape[-1] // 2 + 2 + low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) + low = F.pad(low, (pad, pad), "reflect") + low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) + low = F.conv_transpose1d( + low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 + ) + low = low[..., pad - 1 : -pad] + rest = rearrange( + rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x \ No newline at end of file diff --git a/stable_audio_tools/training/__init__.py b/stable_audio_tools/training/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..f77486b07a478bc88359bf2ece8b9c860df1b054 --- /dev/null +++ b/stable_audio_tools/training/__init__.py @@ -0,0 +1 @@ +from .factory import create_training_wrapper_from_config, create_demo_callback_from_config diff --git a/stable_audio_tools/training/autoencoders.py b/stable_audio_tools/training/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0a2029b5bad9c2a3b5f5a2f1fb3c5cd36bdd4b --- /dev/null +++ b/stable_audio_tools/training/autoencoders.py @@ -0,0 +1,480 @@ +import torch +import torchaudio +import wandb +from einops import rearrange +from safetensors.torch import save_file, save_model +from torch import nn, optim +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from ema_pytorch import EMA +import auraloss +import pytorch_lightning as pl +from ..models.autoencoders import AudioAutoencoder +from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss +from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck +from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss +from .utils import create_optimizer_from_config, create_scheduler_from_config + + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image + +class AutoencoderTrainingWrapper(pl.LightningModule): + def __init__( + self, + autoencoder: AudioAutoencoder, + lr: float = 1e-4, + warmup_steps: int = 0, + encoder_freeze_on_warmup: bool = False, + sample_rate=48000, + loss_config: dict = None, + optimizer_configs: dict = None, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: AudioAutoencoder = None + ): + super().__init__() + + self.automatic_optimization = False + + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + + self.force_input_mono = force_input_mono + + self.teacher_model = teacher_model + + if optimizer_configs is None: + optimizer_configs ={ + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + } + + } + + self.optimizer_configs = optimizer_configs + + if loss_config is None: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + loss_config = { + "discriminator": { + "type": "encodec", + "config": { + "n_ffts": scales, + "hop_lengths": hop_sizes, + "win_lengths": win_lengths, + "filters": 32 + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0, + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + }, + "weights": { + "mrstft": 1.0, + } + }, + "time": { + "type": "l1", + "config": {}, + "weights": { + "l1": 0.0, + } + } + } + + self.loss_config = loss_config + + # Spectral reconstruction loss + + stft_loss_args = loss_config['spectral']['config'] + + if self.autoencoder.out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + 
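+ # lrstft is applied per-channel (left/right) further below, in addition to the sum-and-difference (mid/side) loss from sdstft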
self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Discriminator + + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + + self.gen_loss_modules = [] + + # Adversarial and feature matching losses + self.gen_loss_modules += [ + ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), + ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), + ] + + if self.teacher_model is not None: + # Distillation losses + + stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss + AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder + ] + + else: + + # Reconstruction loss + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.autoencoder.out_channels == 2: + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), + AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), + ] + + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.loss_config['time']['weights']['l1'] > 0.0: + self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) + + if self.autoencoder.bottleneck is not None: + self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) + + self.losses_gen = MultiLoss(self.gen_loss_modules) + + self.disc_loss_modules = [ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ] + + self.losses_disc = MultiLoss(self.disc_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + + self.use_ema = use_ema + + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + 
beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + def configure_optimizers(self): + + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) + opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) + + if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + + return [opt_gen, opt_disc] + + def training_step(self, batch, batch_idx): + reals, _ = batch + + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + # Optionally mask out some latents for noise resistance + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + + decoded = self.autoencoder.decode(latents) + + loss_info["decoded"] = decoded + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) + else: + loss_dis = torch.tensor(0.).to(reals) + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen, opt_disc = self.optimizers() + + lr_schedulers = self.lr_schedulers() + + sched_gen = None + sched_disc = None + + if 
lr_schedulers is not None: + sched_gen, sched_disc = lr_schedulers + + # Train the discriminator + if self.global_step % 2 and self.warmed_up: + loss, losses = self.losses_disc(loss_info) + + log_dict = { + 'train/disc_lr': opt_disc.param_groups[0]['lr'] + } + + opt_disc.zero_grad() + self.manual_backward(loss) + opt_disc.step() + + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + + loss, losses = self.losses_gen(loss_info) + + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + opt_gen.step() + + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict = { + 'train/loss': loss.detach(), + 'train/latent_std': latents.std().detach(), + 'train/data_std': data_std.detach(), + 'train/gen_lr': opt_gen.param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + module.eval() + + try: + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + + fakes = module.autoencoder.decode(latents) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def 
create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses \ No newline at end of file diff --git a/stable_audio_tools/training/diffusion.py b/stable_audio_tools/training/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b2934eef1555a854459d89888a1ddeea084cc2df --- /dev/null +++ b/stable_audio_tools/training/diffusion.py @@ -0,0 +1,1431 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +import auraloss +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..inference.sampling import get_alphas_sigmas, sample +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper +from ..models.autoencoders import DiffusionAutoencoder +from ..models.diffusion_prior import PriorType +from .autoencoders import create_loss_modules_from_bottleneck +from .losses import AuralossLoss, MSELoss, MultiLoss +from .utils import create_optimizer_from_config, create_scheduler_from_config + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionUncondTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). 
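+     Training uses the v-objective (target v = alphas * noise - sigmas * x) with continuous timesteps drawn from a scrambled Sobol sequence.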
+ ''' + def __init__( + self, + model: DiffusionModelWrapper, + lr: float = 1e-4 + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + diffusion_input = reals + + loss_info = {} + + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + loss_info["reals"] = diffusion_input + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + with torch.cuda.amp.autocast(): + v = self.diffusion(noised_inputs, t) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionUncondDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_samples = module.diffusion.sample_size + + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + with torch.cuda.amp.autocast(): + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_{trainer.global_step:08}.wav' + fakes = 
fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionCondTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = None, + causal_dropout: float = 0.0, + mask_padding: bool = False, + mask_padding_dropout: float = 0.0, + use_ema: bool = True, + log_loss_info: bool = False, + optimizer_configs: dict = None, + use_reconstruction_loss: bool = False + ): + super().__init__() + + self.diffusion = model + + if use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.mask_padding = mask_padding + self.mask_padding_dropout = mask_padding_dropout + + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.causal_dropout = causal_dropout + + self.loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + mask_key="padding_mask" if self.mask_padding else None, + name="mse_loss" + ) + ] + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.io_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + self.audio_out_channels = out_channels + + if self.audio_out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.loss_modules += [ + AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), + AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), + ] + + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + self.loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. 
Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Replace 1% of t with ones to ensure training on terminal SNR + t = torch.where(torch.rand_like(t) < 0.01, torch.ones_like(t), t) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + diffusion_input = reals + + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + + # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding + use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout + + # Create batch tensor of attention masks from the "mask" field of the metadata array + if use_padding_mask: + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + if use_padding_mask: + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + p.tick("noise") + + extra_args = {} + + if self.causal_dropout > 0.0: + extra_args["causal"] = random.random() < self.causal_dropout + + if use_padding_mask: + extra_args["mask"] = padding_masks + + with torch.cuda.amp.autocast(): + p.tick("amp") + v = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.1, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "v": v, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + }) + + if self.use_reconstruction_loss: + pred = noised_inputs * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffusion.pretransform is not None: + pred = self.diffusion.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + if self.audio_out_channels == 2: + loss_info["pred_left"] = pred[:, 0:1, :] + loss_info["pred_right"] = pred[:, 1:2, :] + loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] + loss_info["audio_reals_right"] = 
loss_info["audio_reals"][:, 1:2, :] + + loss, losses = self.losses(loss_info) + + p.tick("loss") + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(v, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + demo_steps=250, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + demo_cond_from_batch: bool = False, + display_audio_cond: bool = False + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + # If true, the callback will use the metadata from the batch to generate the demo conditioning + self.demo_cond_from_batch = demo_cond_from_batch + + # If true, the callback will display the audio conditioning + self.display_audio_cond = display_audio_cond + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_samples = self.demo_samples + + demo_cond = self.demo_conditioning + + if self.demo_cond_from_batch: + # Get metadata from the batch + demo_cond = batch[1][:self.num_demos] + + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + print("Getting conditioning") + with torch.cuda.amp.autocast(): + conditioning = 
module.diffusion.conditioner(demo_cond, module.device) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + log_dict = {} + + if self.display_audio_cond: + audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) + audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') + + filename = f'demo_audio_cond_{trainer.global_step:08}.wav' + audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, audio_inputs, self.sample_rate) + log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") + log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + + print(f"Generating demo for cfg scale {cfg_scale}") + + with torch.cuda.amp.autocast(): + model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() + +class DiffusionCondInpaintTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. 
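+     This variant trains for inpainting: random spans of the (optionally pretransform-encoded) input are masked to zero and passed back to the model via the 'inpaint_masked_input' and 'inpaint_mask' conditioning keys.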
+ ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + max_mask_segments = 10 + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + self.max_mask_segments = max_mask_segments + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def random_mask(self, sequence, max_mask_length): + b, _, sequence_length = sequence.size() + + # Create a mask tensor for each batch element + masks = [] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + diffusion_input = reals + + p.tick("setup") + + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + p.tick("noise") + + with torch.cuda.amp.autocast(): + p.tick("amp") + v = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.1) + p.tick("diffusion") + + loss_info 
= { + "v": v, + "targets": targets + } + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path): + self.diffusion.model = self.diffusion_ema.ema_model + + save_file(self.diffusion.state_dict(), path) + +class DiffusionCondInpaintDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.demo_cfg_scales = demo_cfg_scales + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + try: + log_dict = {} + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.cuda.amp.autocast(): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = module.diffusion.conditioner(metadata, module.device) + + masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = sample(module.diffusion_ema.model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + with torch.cuda.amp.autocast(): + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + 
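+ # Peak-normalize the generated audio and convert it to 16-bit PCM before writing the demo WAV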
fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + +class DiffusionAutoencoderTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a diffusion autoencoder + ''' + def __init__( + self, + model: DiffusionAutoencoder, + lr: float = 1e-4, + ema_copy = None, + use_reconstruction_loss: bool = False + ): + super().__init__() + + self.diffae = model + + self.diffae_ema = EMA( + self.diffae, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + if model.bottleneck is not None: + # TODO: Use loss config for configurable bottleneck weights and reconstruction losses + loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.out_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + if out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffae.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.diffae.pretransform is not None: + with torch.no_grad(): + reals = self.diffae.pretransform.encode(reals) + + loss_info["reals"] = reals + + #Encode reals, skipping the pretransform since it was already applied + latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + if self.diffae.decoder is not None: + latents = self.diffae.decoder(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != reals.shape[2]: + latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') + + loss_info["latents_upsampled"] = latents + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = 
alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.cuda.amp.autocast(): + v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffae.pretransform is not None: + pred = self.diffae.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std(), + 'train/latent_std': latents.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffae_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.diffae_ema.ema_model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionAutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad() and torch.cuda.amp.autocast(): + latents = module.diffae_ema.ema_model.encode(encoder_input).float() + fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + if module.diffae_ema.ema_model.pretransform is not None: + with torch.no_grad() and torch.cuda.amp.autocast(): + initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) + first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) + first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') + 
                first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
+                first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
+                torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
+
+                log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
+
+                log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
+                    sample_rate=self.sample_rate,
+                    caption=f'First Stage Reconstructed')
+
+                log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
+
+
+        trainer.logger.experiment.log(log_dict)
+
+def create_source_mixture(reals, num_sources=2):
+    # Create a fake mixture source by mixing elements from the training batch together with random offsets
+    source = torch.zeros_like(reals)
+    for i in range(reals.shape[0]):
+        sources_added = 0
+
+        js = list(range(reals.shape[0]))
+        random.shuffle(js)
+        for j in js:
+            if i == j or (i != j and sources_added < num_sources):
+                # Randomly offset the mixed element between 0 and the length of the source
+                seq_len = reals.shape[2]
+                offset = random.randint(0, seq_len-1)
+                # Slice with seq_len - offset so an offset of 0 keeps the full-length source instead of an empty slice
+                source[i, :, offset:] += reals[j, :, :seq_len - offset]
+                if i == j:
+                    # If this is the real one, shift the reals as well to ensure alignment
+                    new_reals = torch.zeros_like(reals[i])
+                    new_reals[:, offset:] = reals[i, :, :seq_len - offset]
+                    reals[i] = new_reals
+                sources_added += 1
+
+    return source
+
+class DiffusionPriorTrainingWrapper(pl.LightningModule):
+    '''
+    Wrapper for training a diffusion prior for inverse problems
+    Prior types:
+        mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
+    '''
+    def __init__(
+            self,
+            model: ConditionedDiffusionModelWrapper,
+            lr: float = 1e-4,
+            ema_copy = None,
+            prior_type: PriorType = PriorType.MonoToStereo,
+            use_reconstruction_loss: bool = False,
+            log_loss_info: bool = False,
+    ):
+        super().__init__()
+
+        self.diffusion = model
+
+        self.diffusion_ema = EMA(
+            self.diffusion,
+            ema_model=ema_copy,
+            beta=0.9999,
+            power=3/4,
+            update_every=1,
+            update_after_step=1,
+            include_online_model=False
+        )
+
+        self.lr = lr
+
+        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
+
+        self.log_loss_info = log_loss_info
+
+        loss_modules = [
+            MSELoss("v",
+                    "targets",
+                    weight=1.0,
+                    name="mse_loss"
+            )
+        ]
+
+        self.use_reconstruction_loss = use_reconstruction_loss
+
+        if use_reconstruction_loss:
+            scales = [2048, 1024, 512, 256, 128, 64, 32]
+            hop_sizes = []
+            win_lengths = []
+            overlap = 0.75
+            for s in scales:
+                hop_sizes.append(int(s * (1 - overlap)))
+                win_lengths.append(s)
+
+            sample_rate = model.sample_rate
+
+            stft_loss_args = {
+                "fft_sizes": scales,
+                "hop_sizes": hop_sizes,
+                "win_lengths": win_lengths,
+                "perceptual_weighting": True
+            }
+
+            out_channels = model.io_channels
+
+            self.audio_out_channels = out_channels
+
+            if model.pretransform is not None:
+                out_channels = model.pretransform.io_channels
+
+            if self.audio_out_channels == 2:
+                self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
+                self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
+
+                # Add left and right channel reconstruction losses in addition to the sum and difference
+                loss_modules += [
+                    AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
+                    AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
+                ]
+
+            else:
+                self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
+
+            loss_modules.append(
+                AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
+            )
+
+        self.losses = MultiLoss(loss_modules)
+
+        self.prior_type = prior_type
+
+    def configure_optimizers(self):
+        return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
+
+    def training_step(self, batch, batch_idx):
+        reals, metadata = batch
+
+        if reals.ndim == 4 and reals.shape[0] == 1:
+            reals = reals[0]
+
+        loss_info = {}
+
+        loss_info["audio_reals"] = reals
+
+        if self.prior_type == PriorType.MonoToStereo:
+            # Condition on a mono mix of the stereo target, repeated across the output channels
+            source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
+            loss_info["audio_reals_mono"] = source
+        elif self.prior_type == PriorType.SourceSeparation:
+            source = create_source_mixture(reals)
+            loss_info["audio_mixture"] = source
+        else:
+            raise ValueError(f"Unknown prior type {self.prior_type}")
+
+        if self.diffusion.pretransform is not None:
+            with torch.no_grad():
+                reals = self.diffusion.pretransform.encode(reals)
+
+                if self.prior_type in [PriorType.MonoToStereo, PriorType.SourceSeparation]:
+                    source = self.diffusion.pretransform.encode(source)
+
+        if self.diffusion.conditioner is not None:
+            with torch.cuda.amp.autocast():
+                conditioning = self.diffusion.conditioner(metadata, self.device)
+        else:
+            conditioning = {}
+
+        loss_info["reals"] = reals
+
+        # Draw uniformly distributed continuous timesteps
+        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
+
+        # Calculate the noise schedule parameters for those timesteps
+        alphas, sigmas = get_alphas_sigmas(t)
+
+        # Combine the ground truth data and the noise
+        alphas = alphas[:, None, None]
+        sigmas = sigmas[:, None, None]
+        noise = torch.randn_like(reals)
+        noised_reals = reals * alphas + noise * sigmas
+        # v-objective target: the model is trained to predict v = alphas * noise - sigmas * x0
+        targets = noise * alphas - reals * sigmas
+
+        with torch.cuda.amp.autocast():
+
+            conditioning['source'] = [source]
+
+            v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
+
+            loss_info.update({
+                "v": v,
+                "targets": targets
+            })
+
+            if self.use_reconstruction_loss:
+                # Recover the denoised estimate from the v prediction for the reconstruction losses
+                pred = noised_reals * alphas - v * sigmas
+
+                loss_info["pred"] = pred
+
+                if self.diffusion.pretransform is not None:
+                    pred = self.diffusion.pretransform.decode(pred)
+                    loss_info["audio_pred"] = pred
+
+                if self.audio_out_channels == 2:
+                    loss_info["pred_left"] = pred[:, 0:1, :]
+                    loss_info["pred_right"] = pred[:, 1:2, :]
+                    loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
+                    loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
+
+            loss, losses = self.losses(loss_info)
+
+        if self.log_loss_info:
+            # Loss debugging logs
+            num_loss_buckets = 10
+            bucket_size = 1 / num_loss_buckets
+            loss_all = F.mse_loss(v, targets, reduction="none")
+
+            sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
+
+            # gather loss_all across all GPUs
+            loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
+
+            # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
+            loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
+
+            # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
+            debug_log_dict = {
+                f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
+            }
+
+            self.log_dict(debug_log_dict)
+
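+        # Log the overall training loss plus each individual loss term tracked by MultiLoss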
log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std() + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + #model = self.diffusion_ema.ema_model + model = self.diffusion + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionPriorDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + encoder_input = demo_reals + + if module.diffusion.conditioner is not None: + with torch.cuda.amp.autocast(): + conditioning_tensors = module.diffusion.conditioner(metadata, module.device) + + else: + conditioning_tensors = {} + + + with torch.no_grad() and torch.cuda.amp.autocast(): + if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: + source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) + elif module.prior_type == PriorType.SourceSeparation: + source = create_source_mixture(encoder_input) + + if module.diffusion.pretransform is not None: + encoder_input = module.diffusion.pretransform.encode(encoder_input) + source_input = module.diffusion.pretransform.encode(source) + else: + source_input = source + + conditioning_tensors['source'] = [source_input] + + fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + #Log the source + filename = f'source_{trainer.global_step:08}.wav' + source = rearrange(source, 'b d n -> d (b n)') + source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, source, self.sample_rate) + + log_dict[f'source'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Source') + 
+ log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) + + trainer.logger.experiment.log(log_dict) \ No newline at end of file diff --git a/stable_audio_tools/training/factory.py b/stable_audio_tools/training/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..a11badb10ad0057a2bedcbb5e584bdd46e99eb9e --- /dev/null +++ b/stable_audio_tools/training/factory.py @@ -0,0 +1,259 @@ +import torch +from torch.nn import Parameter +from ..models.factory import create_model_from_config + +def create_training_wrapper_from_config(model_config, model): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + ema_copy = create_model_from_config(model_config) + ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + + latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) + + teacher_model = training_config.get("teacher_model", None) + if teacher_model is not None: + teacher_model = create_model_from_config(teacher_model) + teacher_model = teacher_model.eval().requires_grad_(False) + + teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) + if teacher_model_ckpt is not None: + teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) + else: + raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") + + return AutoencoderTrainingWrapper( + model, + lr=training_config["learning_rate"], + warmup_steps=training_config.get("warmup_steps", 0), + encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), + sample_rate=model_config["sample_rate"], + loss_config=training_config.get("loss_configs", None), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=use_ema, + ema_copy=ema_copy if use_ema else None, + force_input_mono=training_config.get("force_input_mono", False), + latent_mask_ratio=latent_mask_ratio, + teacher_model=teacher_model + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondTrainingWrapper + return DiffusionUncondTrainingWrapper( + model, + lr=training_config["learning_rate"], + ) + elif model_type == 'diffusion_cond': + from .diffusion import DiffusionCondTrainingWrapper + return DiffusionCondTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + causal_dropout=training_config.get("causal_dropout", 0.0), + mask_padding=training_config.get("mask_padding", False), + mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), + use_ema = training_config.get("use_ema", True), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), + ) + elif model_type == 'diffusion_prior': 
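+        # Diffusion prior training: rebuild the model from its config as an EMA copy, map the
+        # configured prior_type string onto the PriorType enum, and pass both to the training wrapper below.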
+ from .diffusion import DiffusionPriorTrainingWrapper + from ..models.diffusion_prior import PriorType + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + prior_type = training_config.get("prior_type", "mono_stereo") + + if prior_type == "mono_stereo": + prior_type_enum = PriorType.MonoToStereo + elif prior_type == "source_separation": + prior_type_enum = PriorType.SourceSeparation + else: + raise ValueError(f"Unknown prior type: {prior_type}") + + return DiffusionPriorTrainingWrapper( + model, + lr=training_config["learning_rate"], + ema_copy=ema_copy, + prior_type=prior_type_enum, + log_loss_info=training_config.get("log_loss_info", False), + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), + ) + elif model_type == 'diffusion_cond_inpaint': + from .diffusion import DiffusionCondInpaintTrainingWrapper + return DiffusionCondInpaintTrainingWrapper( + model, + lr=training_config["learning_rate"] + ) + elif model_type == 'diffusion_autoencoder': + from .diffusion import DiffusionAutoencoderTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return DiffusionAutoencoderTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config["learning_rate"], + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) + ) + elif model_type == 'musicgen': + from .musicgen import MusicGenTrainingWrapper + + ema_copy = create_model_from_config(model_config).lm + + for name, param in model.lm.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return MusicGenTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config["learning_rate"] + ) + elif model_type == 'lm': + from .lm import AudioLanguageModelTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return AudioLanguageModelTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config.get("learning_rate", None), + use_ema=training_config.get("use_ema", False), + optimizer_configs=training_config.get("optimizer_configs", None), + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_demo_callback_from_config(model_config, **kwargs): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + demo_config = training_config.get("demo", {}) + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderDemoCallback + return AutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], 
+ **kwargs + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondDemoCallback + return DiffusionUncondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"] + ) + elif model_type == "diffusion_autoencoder": + from .diffusion import DiffusionAutoencoderDemoCallback + return DiffusionAutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_prior": + from .diffusion import DiffusionPriorDemoCallback + return DiffusionPriorDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_cond": + from .diffusion import DiffusionCondDemoCallback + + return DiffusionCondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + num_demos=demo_config["num_demos"], + demo_cfg_scales=demo_config["demo_cfg_scales"], + demo_conditioning=demo_config.get("demo_cond", {}), + demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), + display_audio_cond=demo_config.get("display_audio_cond", False), + ) + elif model_type == "diffusion_cond_inpaint": + from .diffusion import DiffusionCondInpaintDemoCallback + + return DiffusionCondInpaintDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + demo_cfg_scales=demo_config["demo_cfg_scales"], + **kwargs + ) + elif model_type == "musicgen": + from .musicgen import MusicGenDemoCallback + + return MusicGenDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_cfg_scales=demo_config["demo_cfg_scales"], + demo_conditioning=demo_config["demo_cond"], + **kwargs + ) + + elif model_type == "lm": + from .lm import AudioLanguageModelDemoCallback + + return AudioLanguageModelDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), + demo_conditioning=demo_config.get("demo_cond", None), + num_demos=demo_config.get("num_demos", 8), + **kwargs + ) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') \ No newline at end of file diff --git a/stable_audio_tools/training/lm.py b/stable_audio_tools/training/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..06b9298ed218cc63633a5c161d449dff2a31de47 --- /dev/null +++ b/stable_audio_tools/training/lm.py @@ -0,0 +1,254 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero 
import rank_zero_only + +from ..models.lm import AudioLanguageModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLanguageModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + ): + super().__init__() + + self.model = model + + self.model.pretransform.requires_grad_(False) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 
+ """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + + codes = self.model.pretransform.tokenize(reals) + + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + + # Interpolate padding masks to the same length as the codes + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + condition_tensors = self.model.conditioner(metadata, self.device) + + lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] + + logits_mask = logits_mask & padding_masks + + cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) + + loss = cross_entropy + + log_dict = { + 'train/loss': loss.detach(), + 'train/cross_entropy': cross_entropy.detach(), + 'train/perplexity': torch.exp(cross_entropy).detach(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for k, ce_q in enumerate(cross_entropy_per_codebook): + log_dict[f'cross_entropy_q{k + 1}'] = ce_q + log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // 
module.model.pretransform.downsampling_ratio + + # demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + # Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.clamp(-1, 1).mul(32766).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/stable_audio_tools/training/losses.py b/stable_audio_tools/training/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..55b154615d833f68ecc7283a7f78d6a4ab9def31 --- /dev/null +++ b/stable_audio_tools/training/losses.py @@ -0,0 +1,101 @@ +import typing as tp + +from torch.nn import functional as F +from torch import nn + +class LossModule(nn.Module): + def __init__(self, name: str, weight: float = 1.0): + super().__init__() + + self.name = name + self.weight = weight + + def forward(self, info, *args, **kwargs): + raise NotImplementedError + +class ValueLoss(LossModule): + def __init__(self, key: str, name, weight: float = 1.0): + super().__init__(name=name, weight=weight) + + self.key = key + + def forward(self, info): + return self.weight * info[self.key] + +class L1Loss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info: + mse_loss = mse_loss[info[self.mask_key]] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class MSELoss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: + mask = info[self.mask_key] + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss[mask] + + mse_loss = mse_loss.mean() + + return self.weight 
* mse_loss + +class AuralossLoss(LossModule): + def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): + super().__init__(name, weight) + + self.auraloss_module = auraloss_module + + self.input_key = input_key + self.target_key = target_key + + def forward(self, info): + loss = self.auraloss_module(info[self.input_key], info[self.target_key]) + + return self.weight * loss + +class MultiLoss(nn.Module): + def __init__(self, losses: tp.List[LossModule]): + super().__init__() + + self.losses = nn.ModuleList(losses) + + def forward(self, info): + total_loss = 0 + + losses = {} + + for loss_module in self.losses: + module_loss = loss_module(info) + total_loss += module_loss + losses[loss_module.name] = module_loss + + return total_loss, losses \ No newline at end of file diff --git a/stable_audio_tools/training/musicgen.py b/stable_audio_tools/training/musicgen.py new file mode 100644 index 0000000000000000000000000000000000000000..603963d9e36db15cfeb8e41245216fdd3dc514bd --- /dev/null +++ b/stable_audio_tools/training/musicgen.py @@ -0,0 +1,232 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from audiocraft.models import MusicGen +from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout, ConditioningAttributes + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + + +class MusicGenTrainingWrapper(pl.LightningModule): + def __init__(self, musicgen_model, lr = 1e-4, ema_copy=None): + super().__init__() + + self.musicgen_model: MusicGen = musicgen_model + + self.musicgen_model.compression_model.requires_grad_(False) + + self.lm = self.musicgen_model.lm + + self.lm.to(torch.float32).train().requires_grad_(True) + + self.lm_ema = EMA(self.lm, ema_model=ema_copy, beta=0.99, update_every=10) + + self.cfg_dropout = ClassifierFreeGuidanceDropout(0.1) + + self.lr = lr + + def configure_optimizers(self): + optimizer = optim.AdamW([*self.lm.parameters()], lr=self.lr, betas=(0.9, 0.95), weight_decay=0.1) + + return optimizer + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. 
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + # Convert reals to mono if necessary + if self.musicgen_model.audio_channels == 1: + reals = reals.mean(dim=1, keepdim=True) + + self.musicgen_model.compression_model.to(self.device).eval() + self.lm.to(self.device).train() + self.lm.condition_provider.to(self.device).eval() + + self.lm.condition_provider.conditioners["description"].device = self.device + self.lm.condition_provider.conditioners["description"].t5.to(self.device).eval() + + with torch.cuda.amp.autocast(): + + codes, _ = self.musicgen_model.compression_model.encode(reals) # [b, k, t] + + attributes = [ConditioningAttributes(text={'description': md["prompt"][0][:512]}) for md in metadata] + attributes = self.lm.cfg_dropout(attributes) + attributes = self.lm.att_dropout(attributes) + tokenized = self.lm.condition_provider.tokenize(attributes) + + with torch.cuda.amp.autocast(enabled=False): + condition_tensors = self.lm.condition_provider(tokenized) + + lm_output = self.lm.compute_predictions( + codes=codes, + conditions = [], + condition_tensors = condition_tensors, + ) + + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] + + cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) + + loss = cross_entropy + + log_dict = { + 'train/loss': loss.detach(), + 'train/cross_entropy': cross_entropy.detach(), + 'train/perplexity': torch.exp(cross_entropy).detach(), + } + + for k, ce_q in enumerate(cross_entropy_per_codebook): + log_dict[f'cross_entropy_q{k + 1}'] = ce_q + log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.lm_ema.update() + + def export_model(self, path): + self.musicgen_model.lm = self.lm_ema.ema_model + export_state_dict = {"state_dict": self.musicgen_model.state_dict()} + + torch.save(export_state_dict, path) + +class MusicGenDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + 
@torch.no_grad() + def on_train_batch_end(self, trainer, module: MusicGenTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_sec = self.demo_samples // self.sample_rate + + try: + print("Getting conditioning") + + prompts = [md["prompt"][:512] for md in self.demo_conditioning] + + for cfg_scale in self.demo_cfg_scales: + + module.musicgen_model.set_generation_params(duration=demo_length_sec, cfg_coef=cfg_scale) + + with torch.cuda.amp.autocast(): + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = module.musicgen_model.generate(prompts, progress=True) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38a3fcccb5420986f863d70e00abcbdb0a8f06a8 --- /dev/null +++ b/stable_audio_tools/training/utils.py @@ -0,0 +1,111 @@ +import torch +import os + +def get_rank(): + """Get rank of current process.""" + + print(os.environ.keys()) + + if "SLURM_PROCID" in os.environ: + return int(os.environ["SLURM_PROCID"]) + + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + return torch.distributed.get_rank() + +class InverseLR(torch.optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + final_lr (float): The final learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. 
<= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.final_lr = final_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + import warnings + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.final_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + +def copy_state_dict(model, state_dict): + """Load state_dict to model, but only for keys that match exactly. + + Args: + model (nn.Module): model to load state_dict. + state_dict (OrderedDict): state_dict to load. + """ + model_state_dict = model.state_dict() + for key in state_dict: + if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: + if isinstance(state_dict[key], torch.nn.Parameter): + # backwards compatibility for serialized parameters + state_dict[key] = state_dict[key].data + model_state_dict[key] = state_dict[key] + + model.load_state_dict(model_state_dict, strict=False) + +def create_optimizer_from_config(optimizer_config, parameters): + """Create optimizer from config. + + Args: + parameters (iterable): parameters to optimize. + optimizer_config (dict): optimizer config. + + Returns: + torch.optim.Optimizer: optimizer. + """ + + optimizer_type = optimizer_config["type"] + + if optimizer_type == "FusedAdam": + from deepspeed.ops.adam import FusedAdam + optimizer = FusedAdam(parameters, **optimizer_config["config"]) + else: + optimizer_fn = getattr(torch.optim, optimizer_type) + optimizer = optimizer_fn(parameters, **optimizer_config["config"]) + return optimizer + +def create_scheduler_from_config(scheduler_config, optimizer): + """Create scheduler from config. + + Args: + scheduler_config (dict): scheduler config. + optimizer (torch.optim.Optimizer): optimizer. + + Returns: + torch.optim.lr_scheduler._LRScheduler: scheduler. 
+ """ + if scheduler_config["type"] == "InverseLR": + scheduler_fn = InverseLR + else: + scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) + scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) + return scheduler \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbbfbbb6a99a36c1bffdb613230aff2fb51fa06 --- /dev/null +++ b/train.py @@ -0,0 +1,128 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +import torch +import pytorch_lightning as pl +import random + +from stable_audio_tools.data.dataset import create_dataloader_from_config +from stable_audio_tools.models import create_model_from_config +from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config +from stable_audio_tools.training.utils import copy_state_dict + +class ExceptionCallback(pl.Callback): + def on_exception(self, trainer, module, err): + print(f'{type(err).__name__}: {err}') + +class ModelConfigEmbedderCallback(pl.Callback): + def __init__(self, model_config): + self.model_config = model_config + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint["model_config"] = self.model_config + +def main(): + + args = get_all_args() + + seed = args.seed + + # Set a different seed for each process if using SLURM + if os.environ.get("SLURM_PROCID") is not None: + seed += int(os.environ.get("SLURM_PROCID")) + + random.seed(seed) + torch.manual_seed(seed) + + #Get JSON config from args.model_config + with open(args.model_config) as f: + model_config = json.load(f) + + with open(args.dataset_config) as f: + dataset_config = json.load(f) + + train_dl = create_dataloader_from_config( + dataset_config, + batch_size=args.batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=model_config["sample_size"], + audio_channels=model_config.get("audio_channels", 2), + ) + + model = create_model_from_config(model_config) + + if args.pretrained_ckpt_path: + copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path)) + + if args.remove_pretransform_weight_norm == "pre_load": + remove_weight_norm_from_model(model.pretransform) + + if args.pretransform_ckpt_path: + model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path)) + + # Remove weight_norm from the pretransform if specified + if args.remove_pretransform_weight_norm == "post_load": + remove_weight_norm_from_model(model.pretransform) + + training_wrapper = create_training_wrapper_from_config(model_config, model) + + wandb_logger = pl.loggers.WandbLogger(project=args.name) + wandb_logger.watch(training_wrapper) + + exc_callback = ExceptionCallback() + + if args.save_dir and isinstance(wandb_logger.experiment.id, str): + checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints") + else: + checkpoint_dir = None + + ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1) + save_model_config_callback = ModelConfigEmbedderCallback(model_config) + + demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) + + #Combine args and config dicts + args_dict = vars(args) + args_dict.update({"model_config": model_config}) + 
args_dict.update({"dataset_config": dataset_config}) + push_wandb_config(wandb_logger, args_dict) + + #Set multi-GPU strategy if specified + if args.strategy: + if args.strategy == "deepspeed": + from pytorch_lightning.strategies import DeepSpeedStrategy + strategy = DeepSpeedStrategy(stage=2, + contiguous_gradients=True, + overlap_comm=True, + reduce_scatter=True, + reduce_bucket_size=5e8, + allgather_bucket_size=5e8, + load_full_weights=True + ) + else: + strategy = args.strategy + else: + strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" + + trainer = pl.Trainer( + devices=args.num_gpus, + accelerator="gpu", + num_nodes = args.num_nodes, + strategy=strategy, + precision=args.precision, + accumulate_grad_batches=args.accum_batches, + callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback], + logger=wandb_logger, + log_every_n_steps=1, + max_epochs=10000000, + default_root_dir=args.save_dir, + gradient_clip_val=args.gradient_clip_val, + reload_dataloaders_every_n_epochs = 0 + ) + + trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/unwrap_model.py b/unwrap_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4afb7bb7c9229143dac893340862e02573c320ed --- /dev/null +++ b/unwrap_model.py @@ -0,0 +1,133 @@ +import argparse +import json +import torch +from torch.nn.parameter import Parameter +from stable_audio_tools.models import create_model_from_config + +if __name__ == '__main__': + args = argparse.ArgumentParser() + args.add_argument('--model-config', type=str, default=None) + args.add_argument('--ckpt-path', type=str, default=None) + args.add_argument('--name', type=str, default='exported_model') + args.add_argument('--use-safetensors', action='store_true') + + args = args.parse_args() + + with open(args.model_config) as f: + model_config = json.load(f) + + model = create_model_from_config(model_config) + + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + + if model_type == 'autoencoder': + from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + from stable_audio_tools.models.factory import create_model_from_config + ema_copy = create_model_from_config(model_config) + ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + + training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint( + args.ckpt_path, + autoencoder=model, + strict=False, + loss_config=training_config["loss_configs"], + use_ema=training_config["use_ema"], + ema_copy=ema_copy if use_ema else None + ) + elif model_type == 'diffusion_uncond': + from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper + training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) + + elif model_type == 'diffusion_autoencoder': + from 
stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False) + elif model_type == 'diffusion_cond': + from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper + + use_ema = training_config.get("use_ema", True) + + training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint( + args.ckpt_path, + model=model, + use_ema=use_ema, + lr=training_config.get("learning_rate", None), + optimizer_configs=training_config.get("optimizer_configs", None), + strict=False + ) + elif model_type == 'diffusion_cond_inpaint': + from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper + training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) + elif model_type == 'diffusion_prior': + from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy) + elif model_type == 'lm': + from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint( + args.ckpt_path, + model=model, + strict=False, + ema_copy=ema_copy, + optimizer_configs=training_config.get("optimizer_configs", None) + ) + + else: + raise ValueError(f"Unknown model type {model_type}") + + print(f"Loaded model from {args.ckpt_path}") + + if args.use_safetensors: + ckpt_path = f"{args.name}.safetensors" + else: + ckpt_path = f"{args.name}.ckpt" + + training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors) + + print(f"Exported model to {ckpt_path}") \ No newline at end of file
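
A minimal sketch for reloading an unwrapped export outside of the Lightning wrappers; the file names are placeholders, and it is assumed that load_ckpt_state_dict accepts both the {"state_dict": ...} checkpoints written by export_model and safetensors exports:

    import json
    from stable_audio_tools.models import create_model_from_config
    from stable_audio_tools.models.utils import load_ckpt_state_dict

    # Placeholder paths: substitute the config and checkpoint passed to / produced by unwrap_model.py
    with open("model_config.json") as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)
    model.load_state_dict(load_ckpt_state_dict("exported_model.ckpt"))
    model = model.eval().requires_grad_(False)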