import gc
import json
import os
import re

import numpy as np
import torch
import torchaudio
from aeiou.viz import audio_spectrogram_image
from einops import rearrange
from safetensors.torch import load_file
from torch.nn import functional as F
from torchaudio import transforms as T

from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
from ..inference.utils import prepare_audio
from ..models.factory import create_model_from_config
from ..models.pretrained import get_pretrained_model
from ..models.utils import load_ckpt_state_dict
from ..training.utils import copy_state_dict

model = None
sample_rate = 44100
sample_size = 524288

def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
    global model, sample_rate, sample_size

    if pretrained_name is not None:
        print(f"Loading pretrained model {pretrained_name}")
        model, model_config = get_pretrained_model(pretrained_name)
    elif model_config is not None and model_ckpt_path is not None:
        print("Creating model from config")
        model = create_model_from_config(model_config)

        print(f"Loading model checkpoint from {model_ckpt_path}")
        # Load checkpoint
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
        #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))

    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    if pretransform_ckpt_path is not None:
        print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
        model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
        print("Done loading pretransform")

    model.to(device).eval().requires_grad_(False)

    if model_half:
        model.to(torch.float16)

    print("Done loading model")

    return model, model_config
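# Example (hypothetical): a minimal sketch of calling load_model() with a local
# config and checkpoint. The file paths below are illustrative placeholders, not
# files shipped with this module.
def _example_load_model():
    # Assumes a model config JSON with "sample_rate" and "sample_size" keys,
    # as read by load_model() above.
    with open("model_config.json") as f:
        model_config = json.load(f)
    return load_model(
        model_config=model_config,
        model_ckpt_path="checkpoints/model.ckpt",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )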
def generate_cond_with_path(
        prompt,
        negative_prompt=None,
        seconds_start=0,
        seconds_total=30,
        latitude=0.0,
        longitude=0.0,
        temperature=0.0,
        humidity=0.0,
        wind_speed=0.0,
        pressure=0.0,
        minutes_of_day=0.0,
        day_of_year=0.0,
        cfg_scale=6.0,
        steps=250,
        preview_every=None,
        seed=-1,
        sampler_type="dpmpp-2m-sde",
        sigma_min=0.03,
        sigma_max=50,
        cfg_rescale=0.4,
        use_init=False,
        init_audio=None,
        init_noise_level=1.0,
        mask_cropfrom=None,
        mask_pastefrom=None,
        mask_pasteto=None,
        mask_maskstart=None,
        mask_maskend=None,
        mask_softnessL=None,
        mask_softnessR=None,
        mask_marination=None,
        batch_size=1,
        destination_folder=None,
        file_name=None
    ):

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    global preview_images
    preview_images = []

    if preview_every == 0:
        preview_every = None

    conditioning = [{
        "prompt": prompt,
        "latitude": latitude,
        "longitude": longitude,
        "temperature": temperature,
        "humidity": humidity,
        "wind_speed": wind_speed,
        "pressure": pressure,
        "minutes_of_day": minutes_of_day,
        "day_of_year": day_of_year,
        "seconds_start": seconds_start,
        "seconds_total": seconds_total
    }] * batch_size

    if negative_prompt:
        negative_conditioning = [{
            "prompt": negative_prompt,
            "latitude": latitude,
            "longitude": longitude,
            "temperature": temperature,
            "humidity": humidity,
            "wind_speed": wind_speed,
            "pressure": pressure,
            "minutes_of_day": minutes_of_day,
            "day_of_year": day_of_year,
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }] * batch_size
    else:
        negative_conditioning = None

    # Get the device from the model
    device = next(model.parameters()).device

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    if init_audio is not None:
        in_sr, init_audio = init_audio
        # Turn into torch tensor, converting from int16 to float32
        init_audio = torch.from_numpy(init_audio).float().div(32767)

        if init_audio.dim() == 1:
            init_audio = init_audio.unsqueeze(0)  # [1, n]
        elif init_audio.dim() == 2:
            init_audio = init_audio.transpose(0, 1)  # [n, 2] -> [2, n]

        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
            init_audio = resample_tf(init_audio)

        audio_length = init_audio.shape[-1]

        if audio_length > sample_size:
            # Pad the input length up to the next multiple of the model's minimum input length
            input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length

        init_audio = (sample_rate, init_audio)

    def progress_callback(callback_info):
        global preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]

        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)

            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()

            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))

    # If inpainting, send mask args
    # This will definitely change in the future
    if mask_cropfrom is not None:
        mask_args = {
            "cropfrom": mask_cropfrom,
            "pastefrom": mask_pastefrom,
            "pasteto": mask_pasteto,
            "maskstart": mask_maskstart,
            "maskend": mask_maskend,
            "softnessL": mask_softnessL,
            "softnessR": mask_softnessR,
            "marination": mask_marination,
        }
    else:
        mask_args = None

    # Do the audio generation
    audio = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        negative_conditioning=negative_conditioning,
        steps=steps,
        cfg_scale=cfg_scale,
        batch_size=batch_size,
        sample_size=input_sample_size,
        sample_rate=sample_rate,
        seed=seed,
        device=device,
        sampler_type=sampler_type,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        init_audio=init_audio,
        init_noise_level=init_noise_level,
        mask_args=mask_args,
        callback=progress_callback if preview_every is not None else None,
        scale_phi=cfg_rescale
    )

    # Convert to int16 WAV format, peak-normalizing first
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save to the desired folder with the required filename, adding the .wav extension
    if destination_folder is not None and file_name is not None:
        torchaudio.save(f"{destination_folder}/{file_name}.wav", audio, sample_rate)

    # Let's look at a nice spectrogram too
    # audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
    # return ("output.wav", [audio_spectrogram, *preview_images])

def generate_lm(
        temperature=1.0,
        top_p=0.95,
        top_k=0,
        batch_size=1,
    ):

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Get the device from the model
    device = next(model.parameters()).device

    audio = model.generate_audio(
        batch_size=batch_size,
        max_gen_len=sample_size // model.pretransform.downsampling_ratio,
        conditioning=None,
        temp=temperature,
        top_p=top_p,
        top_k=top_k,
        use_cache=True
    )

    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", audio, sample_rate)

    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    return ("output.wav", [audio_spectrogram])
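# Example (hypothetical): a single conditioned generation. This is a minimal
# sketch assuming load_model() has already been called; the coordinates and
# weather values below are illustrative placeholders, as is the "outputs" folder.
def _example_generate_cond():
    os.makedirs("outputs", exist_ok=True)
    generate_cond_with_path(
        prompt="",
        seconds_start=0,
        seconds_total=22,
        latitude=51.5,
        longitude=-0.1,
        temperature=12.0,
        humidity=80.0,
        wind_speed=3.0,
        pressure=1012.0,
        minutes_of_day=360.0,
        day_of_year=120.0,
        cfg_scale=4.0,
        steps=250,
        destination_folder="outputs",
        file_name="example",
    )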
def autoencoder_process(audio, latent_noise, n_quantizers):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Get the device from the model
    device = next(model.parameters()).device

    in_sr, audio = audio

    audio = torch.from_numpy(audio).float().div(32767).to(device)

    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    else:
        audio = audio.transpose(0, 1)

    audio = model.preprocess_audio_for_encoder(audio, in_sr)
    # Note: If you need to do chunked encoding, to reduce VRAM,
    # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
    # To turn it off, do chunked=False
    # Optimal overlap and chunk_size values will depend on the model.
    # See encode_audio & decode_audio in autoencoders.py for more info

    # Get dtype of model
    dtype = next(model.parameters()).dtype

    audio = audio.to(dtype)

    if n_quantizers > 0:
        latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
    else:
        latents = model.encode_audio(audio, chunked=False)

    if latent_noise > 0:
        latents = latents + torch.randn_like(latents) * latent_noise

    audio = model.decode_audio(latents, chunked=False)

    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", audio, sample_rate)

    return "output.wav"
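# Example (hypothetical): the chunked encode/decode path described in the note
# inside autoencoder_process() above. A sketch only: it assumes the global model
# is an autoencoder already loaded via load_model(), and the overlap/chunk_size
# values are illustrative and model-dependent (see encode_audio & decode_audio
# in autoencoders.py).
def _example_chunked_roundtrip(audio):
    # Encode and decode in overlapping chunks to reduce peak VRAM usage
    latents = model.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
    return model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)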
def load_and_generate(model_path, json_dir, output_dir):
    """Load JSON files and generate audio for each set of conditions."""
    # List all files in the json_dir
    files = os.listdir(json_dir)

    # Filter for JSON files
    json_files = [file for file in files if file.endswith('.json')]

    if not json_files:
        print(f"No JSON files found in {json_dir}. Please check the directory path and file permissions.")
        return

    for json_filename in json_files:
        json_file_path = os.path.join(json_dir, json_filename)
        try:
            with open(json_file_path, 'r') as file:
                data = json.load(file)
        except Exception as e:
            print(f"Failed to read or parse {json_file_path}: {e}")
            continue

        # Print the JSON path
        print(json_file_path)

        # Extract conditions from JSON
        conditions = {
            'birdSpecies': data['birdSpecies'],
            'latitude': data['coord']['lat'],
            'longitude': data['coord']['lon'],
            'temperature': data['main']['temp'],
            'humidity': data['main']['humidity'],
            'pressure': data['main']['pressure'],
            'wind_speed': data['wind']['speed'],
            'day_of_year': data['dayOfYear'],
            'minutes_of_day': data['minutesOfDay']
        }

        # Extract base filename components
        step_number = re.search(r'step=(\d+)', model_path).group(1)
        bird_species = conditions['birdSpecies'].replace(' ', '_')
        base_filename = f"{bird_species}_{os.path.splitext(json_filename)[0]}_{step_number}_cfg_scale_"

        # An array of cfg scale values to test
        cfg_scales = [1.8, 2.5, 4.0, 5.0, 12.0]

        # Generate audio once for each cfg scale value
        for scale in cfg_scales:
            generate_cond_with_path(
                prompt="",
                negative_prompt="",
                seconds_start=0,
                seconds_total=22,
                latitude=conditions['latitude'],
                longitude=conditions['longitude'],
                temperature=conditions['temperature'],
                humidity=conditions['humidity'],
                wind_speed=conditions['wind_speed'],
                pressure=conditions['pressure'],
                minutes_of_day=conditions['minutes_of_day'],
                day_of_year=conditions['day_of_year'],
                cfg_scale=scale,
                steps=250,
                preview_every=None,
                seed=-1,
                sampler_type="dpmpp-2m-sde",
                sigma_min=0.03,
                sigma_max=50,
                cfg_rescale=0.4,
                use_init=False,
                init_audio=None,
                init_noise_level=1.0,
                mask_cropfrom=None,
                mask_pastefrom=None,
                mask_pasteto=None,
                mask_maskstart=None,
                mask_maskend=None,
                mask_softnessL=None,
                mask_softnessR=None,
                mask_marination=None,
                batch_size=1,
                destination_folder=output_dir,
                file_name=base_filename + str(scale)
            )

def runTests(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False, json_dir=None, output_dir=None):
    assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), \
        "Must specify either pretrained name or provide a model config and checkpoint, but not both"

    if model_config_path is not None:
        # Load config from json file
        with open(model_config_path) as f:
            model_config = json.load(f)
    else:
        model_config = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Process all JSON files and generate audio
    load_and_generate(ckpt_path, json_dir, output_dir)
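# Example (hypothetical) entry point: a minimal sketch of invoking runTests()
# directly. No argument parsing is wired up in this module, so the paths below
# are placeholders; note the checkpoint filename must contain "step=<N>" for
# the regex in load_and_generate() to extract a step number.
if __name__ == "__main__":
    runTests(
        model_config_path="model_config.json",
        ckpt_path="checkpoints/model_step=100000.ckpt",
        json_dir="conditions_json",
        output_dir="generated_audio",
    )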