import os
import json
import torch
import argparse
import multiprocessing
import random
from tqdm import tqdm
from safetensors.torch import load_file
from diffusers import AutoencoderOobleck
import soundfile as sf
from model import TangoFlux


def generate_audio_chunk(args, chunk, gpu_id, output_dir, samplerate, return_dict, process_id):
    """
    Generate audio for a chunk of text prompts on a specific GPU.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Initialize model
        config = {
            'num_layers': 6,
            'num_single_layers': 18,
            'in_channels': 64,
            'attention_head_dim': 128,
            'joint_attention_dim': 1024,
            'num_attention_heads': 8,
            'audio_seq_len': 645,
            'max_duration': 30,
            'uncondition': False,
            'text_encoder_name': "google/flan-t5-large"
        }
        model = TangoFlux(config)
        print(f"Process {process_id}: Loading model from {args.model} on {device}")
        w1 = load_file(args.model)
        model.load_state_dict(w1, strict=False)
        model = model.to(device)
        model.eval()

        # Initialize VAE
        vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0", subfolder='vae')
        vae = vae.to(device)
        vae.eval()

        outputs = []
        for item in tqdm(chunk, total=len(chunk), desc=f"GPU {gpu_id}"):
            text = item['captions']
            # Skip prompts whose first sample already exists on disk
            if os.path.exists(os.path.join(output_dir, f"id_{item['id']}_sample1.wav")):
                print(f"id_{item['id']} already exists; skipping.")
                continue
            with torch.no_grad():
                latent = model.inference_flow(
                    text,
                    num_inference_steps=args.num_steps,
                    guidance_scale=args.guidance_scale,
                    duration=10,
                    num_samples_per_prompt=args.num_samples
                )
                # 220 is the latent length of a 10-second AudioCaps clip encoded
                # with this VAE; modify it if you change the duration above.
                latent = latent[:, :220, :]
                wave = vae.decode(latent.transpose(2, 1)).sample.cpu()

            for i in range(args.num_samples):
                filename = f"id_{item['id']}_sample{i+1}.wav"
                filepath = os.path.join(output_dir, filename)
                sf.write(filepath, wave[i].T, samplerate)
                outputs.append({
                    "id": item['id'],
                    "sample": i + 1,
                    "path": filepath,
                    "captions": text
                })

        return_dict[process_id] = outputs
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []


def split_into_chunks(data, num_chunks):
    """
    Split data into num_chunks approximately equal parts.
    """
    avg = len(data) // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * avg
        # The last chunk takes the remainder
        end = (i + 1) * avg if i != num_chunks - 1 else len(data)
        chunks.append(data[start:end])
    return chunks


def main():
    parser = argparse.ArgumentParser(description="Generate audio using multiple GPUs")
    parser.add_argument('--num_steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--model', type=str, required=True, help='Path to TangoFlux safetensors weights')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples per prompt')
    parser.add_argument('--output_dir', type=str, default='output', help='Directory to save outputs')
    parser.add_argument('--json_path', type=str, required=True, help='Path to input JSON file')
    parser.add_argument('--sample_size', type=int, default=20000, help='Number of prompts to sample for CRPO')
    parser.add_argument('--guidance_scale', type=float, default=4.5, help='Guidance scale used for generation')
    args = parser.parse_args()

    # Check GPU availability
    num_gpus = torch.cuda.device_count()
    sample_size = args.sample_size

    # Load JSON data
    try:
        with open(args.json_path, 'r') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading JSON file {args.json_path}: {e}")
        return
    if not isinstance(data, list):
        print("Error: JSON data is not a list.")
        return
    if len(data) < sample_size:
        print(f"Warning: JSON data contains only {len(data)} items. Sampling all available data.")
        sampled = data
    else:
        sampled = random.sample(data, sample_size)

    # Shuffle, then split the data into one chunk per available GPU
    random.shuffle(sampled)
    chunks = split_into_chunks(sampled, num_gpus)

    # Prepare output directory
    os.makedirs(args.output_dir, exist_ok=True)
    samplerate = 44100

    # Manager for inter-process communication
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []

    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=generate_audio_chunk,
            args=(
                args,
                chunks[i],
                i,  # GPU ID
                args.output_dir,
                samplerate,
                return_dict,
                i,  # Process ID
            )
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")

    # Aggregate results: reconstruct the expected output paths from the sampled
    # prompt list (iterate over len(sampled), not sample_size, since fewer items
    # may have been available than requested).
    audio_info_list = [
        [
            {
                "path": f"{args.output_dir}/id_{sampled[j]['id']}_sample{i}.wav",
                "duration": sampled[j]["duration"],
                "captions": sampled[j]["captions"]
            }
            for i in range(1, args.num_samples + 1)
        ]
        for j in range(len(sampled))
    ]

    with open(f'{args.output_dir}/results.json', 'w') as f:
        json.dump(audio_info_list, f)

    print(f"All audio samples have been generated and saved to {args.output_dir}")


if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    main()