import gc
import os
import yaml
import inspect
import torch
import torch.nn as nn
import numpy as np
from diffusers import DDIMScheduler
from PIL import Image

# from basicsr.utils import tensor2img
from diffusers import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    AutoTokenizer,
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
    ClapTextModelWithProjection,
    RobertaTokenizer,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
)
from diffusers.utils.import_utils import is_xformers_available

from src.module.unet.unet_2d_condition import (
    CustomUNet2DConditionModel,
    UNet2DConditionModel,
)
from src.module.unet.estimator import _UNet2DConditionModel
from src.utils.inversion import DDIMInversion
from src.module.unet.attention_processor import (
    IPAttnProcessor,
    AttnProcessor,
    Resampler,
)
from src.model.sampler import Sampler
from src.utils.audio_processing import (
    extract_fbank,
    wav_to_fbank,
    TacotronSTFT,
    maybe_add_dimension,
)

import sys

sys.path.append("src/module/tango")
from tools.torch_tools import wav_to_fbank as tng_wav_to_fbank

CWD = os.getcwd()


class TangoPipeline:
    def __init__(
        self,
        sd_id="declare-lab/tango",
        NUM_DDIM_STEPS=100,
        precision=torch.float32,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        **kwargs,
    ):
        import sys
        import json
        import torch
        from huggingface_hub import snapshot_download

        sys.path.append("./src/module/tango")
        from tango2.models import AudioDiffusion
        from audioldm.audio.stft import TacotronSTFT as tng_TacotronSTFT
        from audioldm.variational_autoencoder import AutoencoderKL

        path = snapshot_download(repo_id=sd_id)
        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))
        main_config["unet_model_config_path"] = os.path.join(
            CWD, "src/module/tango", main_config["unet_model_config_path"]
        )

        unet = self._set_unet2dconditional_model(
            CustomUNet2DConditionModel,
            unet_model_name=main_config["unet_model_name"],
            unet_model_config_path=main_config["unet_model_config_path"],
        ).to(device)
        feature_estimator = self._set_unet2dconditional_model(
            _UNet2DConditionModel,
            unet_model_name=main_config["unet_model_name"],
            unet_model_config_path=main_config["unet_model_config_path"],
        ).to(device)

        ##### Load pretrained model #####
        vae = AutoencoderKL(**vae_config).to(device)
        vae.dtype = torch.float32  # avoid a missing-attribute error
        stft = tng_TacotronSTFT(**stft_config).to(device)
        model = AudioDiffusion(**main_config).to(device)
        model.unet = unet  # replace the unet with the custom unet

        vae_weights = torch.load(
            "{}/pytorch_model_vae.bin".format(path), map_location=device
        )
        stft_weights = torch.load(
            "{}/pytorch_model_stft.bin".format(path), map_location=device
        )
        main_weights = torch.load(
            "{}/pytorch_model_main.bin".format(path), map_location=device
        )
        vae.load_state_dict(vae_weights)
        stft.load_state_dict(stft_weights)
        model.load_state_dict(main_weights)
        unet_weights = {
            ".".join(layer.split(".")[1:]): param
            for layer, param in model.named_parameters()
            if "unet" in layer
        }
        feature_estimator.load_state_dict(unet_weights)

        vae.eval()
        stft.eval()
        model.eval()
        feature_estimator.eval()

        # Free memory
        del vae_weights
        del stft_weights
        del main_weights
        del unet_weights

        feature_estimator.scheduler = DDIMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

        # Create the pipeline for audio editing
        onestep_pipe = Sampler(
            vae=vae,
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            unet=model.unet,
            feature_estimator=feature_estimator,
            scheduler=DDIMScheduler.from_pretrained(
                main_config["scheduler_name"], subfolder="scheduler"
            ),
            device=device,
            precision=precision,
        )
        onestep_pipe.use_cross_attn = True

        gc.collect()
        onestep_pipe.enable_attention_slicing()
        if is_xformers_available():
            onestep_pipe.feature_estimator.enable_xformers_memory_efficient_attention()
            onestep_pipe.enable_xformers_memory_efficient_attention()

        self.pipe = onestep_pipe
        self.fn_STFT = stft
        self.vae_scale_factor = vae_config["ddconfig"]["ch_mult"][-1]
        self.NUM_DDIM_STEPS = NUM_DDIM_STEPS
        self.num_tokens = 512  # flan-t5
        self.precision = precision
        self.device = device
        # self.load_adapter()  # replace the first self-attn layer with cross-attention over the difference trajectory

    def _set_unet2dconditional_model(
        self,
        cls_obj: UNet2DConditionModel,
        *,
        unet_model_name=None,
        unet_model_config_path=None,
    ):
        assert (
            unet_model_name is not None or unet_model_config_path is not None
        ), "Either a UNet pretrained model name or a config file path is required"

        if unet_model_config_path:
            unet_config = cls_obj.load_config(unet_model_config_path)
            unet = cls_obj.from_config(unet_config, subfolder="unet")
            unet.set_from = "random"
        else:
            unet = cls_obj.from_pretrained(unet_model_name, subfolder="unet")
            unet.set_from = "pre-trained"
            unet.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
            unet.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
        return unet

    @torch.no_grad()
    def decode_latents(self, latents):
        return self.pipe.vae.decode_first_stage(latents)

    @torch.no_grad()
    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        return self.pipe.vae.decode_to_waveform(mel_spectrogram)

    def get_fbank(self, audio_or_path, stft_cfg, return_intermediate=False):
        r"""Helper function to get a fbank from an audio file."""
        if isinstance(audio_or_path, torch.Tensor):
            return maybe_add_dimension(audio_or_path, 4)

        if isinstance(audio_or_path, str):
            fbank, log_stft, wav = tng_wav_to_fbank(
                [audio_or_path],
                fn_STFT=self.fn_STFT,
                target_length=stft_cfg.filter_length,
            )
            fbank = maybe_add_dimension(fbank, 4)  # (B, C, T, F)
            if return_intermediate:
                return fbank, log_stft, wav
            return fbank

    @torch.no_grad()
    def encode_fbank(self, fbank):
        return self.pipe.vae.get_first_stage_encoding(
            self.pipe.vae.encode_first_stage(fbank)
        )

    @torch.no_grad()
    def fbank2latent(self, fbank):
        latent = self.encode_fbank(fbank)
        return latent

    def ddim_inv(self, latent, prompt, emb_im=None, save_kv=True, mode="mix", prediction_type="v_prediction"):
        ddim_inv = DDIMInversion(model=self.pipe, NUM_DDIM_STEPS=self.NUM_DDIM_STEPS)
        ddim_latents = ddim_inv.invert(
            ddim_latents=latent.unsqueeze(2),
            prompt=prompt,
            emb_im=emb_im,
            save_kv=save_kv,
            mode=mode,
            prediction_type=prediction_type,
        )
        return ddim_latents

    def init_proj(self, precision):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to("cuda", dtype=precision)
        return image_proj_model

    def load_adapter(self):
        scale = 1.0
        attn_procs = {}
        for name in self.pipe.unet.attn_processors.keys():
            cross_attention_dim = None
            if name.startswith("mid_block"):
                hidden_size = self.pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.pipe.unet.config.block_out_channels[block_id]
            # Only the first self-attention layer is used to cross-attend the difference trajectory
            if name.endswith("attn1.processor"):
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=scale,
                    num_tokens=self.num_tokens,
                ).to("cuda", dtype=self.precision)
        self.pipe.unet.set_attn_processor(attn_procs)
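
# Hedged usage sketch (not part of the original pipelines): it illustrates the intended editing
# flow with TangoPipeline: load an audio file as a fbank, encode it to a VAE latent, then run
# DDIM inversion conditioned on the source prompt. The function name, audio path, prompt, and
# the SimpleNamespace STFT settings are illustrative assumptions, not values taken from this
# repository's configs.
def _example_tango_inversion(audio_path="input.wav", prompt="a dog barking"):
    from types import SimpleNamespace

    pipe = TangoPipeline(sd_id="declare-lab/tango", NUM_DDIM_STEPS=100)
    stft_cfg = SimpleNamespace(filter_length=1024, hop_length=160)  # assumed STFT settings
    fbank = pipe.get_fbank(audio_path, stft_cfg).to(pipe.device)  # (B, C, T, F) mel fbank
    latent = pipe.fbank2latent(fbank)  # encode the fbank with the Tango VAE
    # DDIM inversion traces the latent trajectory that the editing step later reuses
    return pipe.ddim_inv(latent, prompt)
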

class AudioLDMPipeline:
    def __init__(
        self,
        sd_id="cvssp/audioldm-l-full",
        ip_id="cvssp/audioldm-l-full",
        NUM_DDIM_STEPS=50,
        precision=torch.float32,
        ip_scale=0,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        onestep_pipe = Sampler(
            vae=AutoencoderKL.from_pretrained(
                sd_id, subfolder="vae", torch_dtype=precision
            ),
            tokenizer=RobertaTokenizerFast.from_pretrained(
                sd_id, subfolder="tokenizer"
            ),
            text_encoder=ClapTextModelWithProjection.from_pretrained(
                sd_id, subfolder="text_encoder", torch_dtype=precision
            ),
            unet=CustomUNet2DConditionModel.from_pretrained(
                sd_id, subfolder="unet", torch_dtype=precision
            ),
            feature_estimator=_UNet2DConditionModel.from_pretrained(
                sd_id,
                subfolder="unet",
                vae=None,
                text_encoder=None,
                tokenizer=None,
                scheduler=DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler"),
                safety_checker=None,
                feature_extractor=None,
            ),
            scheduler=DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler"),
            device=device,
            precision=precision,
        )
        onestep_pipe.vocoder = SpeechT5HifiGan.from_pretrained(
            sd_id, subfolder="vocoder", torch_dtype=precision
        )
        onestep_pipe.use_cross_attn = False

        gc.collect()
        onestep_pipe = onestep_pipe.to(device)
        onestep_pipe.vocoder.to(device)
        onestep_pipe.enable_attention_slicing()
        if is_xformers_available():
            onestep_pipe.feature_estimator.enable_xformers_memory_efficient_attention()
            onestep_pipe.enable_xformers_memory_efficient_attention()

        self.pipe = onestep_pipe
        self.vae_scale_factor = 2 ** (len(self.pipe.vae.config.block_out_channels) - 1)
        self.NUM_DDIM_STEPS = NUM_DDIM_STEPS
        self.precision = precision
        self.device = device
        self.num_tokens = 64  # fixed by the pretrained model
        self.fn_STFT = TacotronSTFT(
            filter_length=1024,
            hop_length=160,
            win_length=1024,
            n_mel_channels=64,
            sampling_rate=16000,
            mel_fmin=0,
            mel_fmax=8000,
        )
        # self.load_adapter()

    @torch.no_grad()
    def decode_latents(self, latents):
        latents = 1 / self.pipe.vae.config.scaling_factor * latents
        mel_spectrogram = self.pipe.vae.decode(latents).sample
        return mel_spectrogram

    @torch.no_grad()
    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)
        waveform = self.pipe.vocoder(
            mel_spectrogram.to(device=self.device, dtype=self.precision)
        )
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform
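    # Hedged decoding sketch (illustration only, not an original helper): chaining the two
    # methods above converts a VAE latent back into audio via the HiFi-GAN vocoder, e.g.
    #   mel = self.decode_latents(latents)           # latent -> mel spectrogram (B, 1, T, n_mels)
    #   wav = self.mel_spectrogram_to_waveform(mel)  # mel -> waveform, returned as float32 on CPU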
    @torch.no_grad()
    def fbank2latent(self, fbank):
        latent = self.encode_fbank(fbank)
        return latent

    def get_fbank(self, audio_or_path, stft_cfg, return_intermediate=False):
        r"""Helper function to get a fbank from an audio file."""
        if isinstance(audio_or_path, torch.Tensor):
            return maybe_add_dimension(audio_or_path, 3)

        if isinstance(audio_or_path, str):
            fbank, log_stft, wav = extract_fbank(
                audio_or_path,
                fn_STFT=self.fn_STFT,
                target_length=stft_cfg.filter_length,
                hop_size=stft_cfg.hop_length,
            )
            fbank = maybe_add_dimension(fbank, 3)  # (C, T, F)
            if return_intermediate:
                return fbank, log_stft, wav
            return fbank

    def wav2fbank(self, wav, target_length):
        fbank, log_magnitudes_stft = wav_to_fbank(wav, target_length, self.fn_STFT)
        return fbank, log_magnitudes_stft

    @torch.no_grad()
    def encode_fbank(self, fbank):
        latent = self.pipe.vae.encode(fbank)["latent_dist"].mean
        # NOTE: scale the noise latent
        latent = latent * self.pipe.scheduler.init_noise_sigma
        return latent

    def ddim_inv(self, latent, prompt, emb_im=None, save_kv=True, mode="mix", prediction_type="epsilon"):
        ddim_inv = DDIMInversion(model=self.pipe, NUM_DDIM_STEPS=self.NUM_DDIM_STEPS)
        ddim_latents = ddim_inv.invert(
            ddim_latents=latent.unsqueeze(2),
            prompt=prompt,
            emb_im=emb_im,
            save_kv=save_kv,
            mode=mode,
            prediction_type=prediction_type,
        )
        return ddim_latents

    def init_proj(self, precision):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to("cuda", dtype=precision)
        return image_proj_model

    # @torch.inference_mode()
    # def get_image_embeds(self, pil_image):
    #     if isinstance(pil_image, Image.Image):
    #         pil_image = [pil_image]
    #     clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    #     clip_image = clip_image.to('cuda', dtype=self.precision)
    #     clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    #     image_prompt_embeds = self.image_proj_model(clip_image_embeds)
    #     uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2].detach()
    #     uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds).detach()
    #     return image_prompt_embeds, uncond_image_prompt_embeds

    def load_adapter(self):
        scale = 1.0
        attn_procs = {}
        for name in self.pipe.unet.attn_processors.keys():
            cross_attention_dim = None
            if name.startswith("mid_block"):
                hidden_size = self.pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.pipe.unet.config.block_out_channels[block_id]
            # Only the first self-attention layer is used to cross-attend the difference trajectory
            if name.endswith("attn1.processor"):
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=scale,
                    num_tokens=self.num_tokens,
                ).to("cuda", dtype=self.precision)
        self.pipe.unet.set_attn_processor(attn_procs)

    # def load_adapter(self, model_path, scale=1.0):
    #     from src.unet.attention_processor import IPAttnProcessor, AttnProcessor, Resampler
    #     attn_procs = {}
    #     for name in self.pipe.unet.attn_processors.keys():
    #         cross_attention_dim = None if name.endswith("attn1.processor") else self.pipe.unet.config.cross_attention_dim
    #         if name.startswith("mid_block"):
    #             hidden_size = self.pipe.unet.config.block_out_channels[-1]
    #         elif name.startswith("up_blocks"):
    #             block_id = int(name[len("up_blocks.")])
    #             hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[block_id]
    #         elif name.startswith("down_blocks"):
    #             block_id = int(name[len("down_blocks.")])
    #             hidden_size = self.pipe.unet.config.block_out_channels[block_id]
    #         if cross_attention_dim is None:
    #             attn_procs[name] = AttnProcessor()
    #         else:
    #             attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
    #                                                scale=scale, num_tokens=self.num_tokens).to('cuda', dtype=self.precision)
    #     self.pipe.unet.set_attn_processor(attn_procs)
    #     state_dict = torch.load(model_path, map_location="cpu")
    #     self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
    #     ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    #     ip_layers.load_state_dict(state_dict["ip_adapter"], strict=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.pipe.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.pipe.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.pipe.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.pipe.scheduler.init_noise_sigma
        return latents


if __name__ == "__main__":
    # pipeline = AudioLDMPipeline(
    #     sd_id="cvssp/audioldm-l-full", ip_id="cvssp/audioldm-l-full", NUM_DDIM_STEPS=50
    # )
    pipeline = TangoPipeline(
        sd_id="declare-lab/tango",
        ip_id="declare-lab/tango",
        NUM_DDIM_STEPS=50,
        precision=torch.float16,
    )
    print(pipeline.__dict__)
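    # Hedged continuation sketch (commented out): a possible VAE round trip with the pipeline
    # built above. "input.wav" and the SimpleNamespace STFT settings are placeholders, not
    # values read from the downloaded stft_config.json.
    # from types import SimpleNamespace
    # stft_cfg = SimpleNamespace(filter_length=1024, hop_length=160)
    # fbank = pipeline.get_fbank("input.wav", stft_cfg).to(pipeline.device, dtype=torch.float16)
    # latent = pipeline.fbank2latent(fbank)
    # mel = pipeline.decode_latents(latent)            # latent -> mel spectrogram
    # wav = pipeline.mel_spectrogram_to_waveform(mel)  # mel -> waveform via the Tango VAE vocoder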