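"""Multi-GPU batch inference script for TangoFlux.

Samples text prompts from a JSON file, splits them across the available
GPUs, generates audio with TangoFlux, decodes the latents with the
Stable Audio Open VAE, and writes the waveforms plus a results.json
manifest to the output directory for CRPO preference data collection.
"""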
import argparse
import json
import multiprocessing
import os
import random

import soundfile as sf
import torch
from diffusers import AutoencoderOobleck
from safetensors.torch import load_file
from tqdm import tqdm

from model import TangoFlux
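# Requires torch, diffusers, safetensors, soundfile, and tqdm, plus the
# local TangoFlux definition in model.py alongside this script.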
def generate_audio_chunk(args, chunk, gpu_id, output_dir, samplerate, return_dict, process_id):
    """Generate audio for a chunk of text prompts on a specific GPU."""
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Initialize the TangoFlux model.
        config = {
            'num_layers': 6,
            'num_single_layers': 18,
            'in_channels': 64,
            'attention_head_dim': 128,
            'joint_attention_dim': 1024,
            'num_attention_heads': 8,
            'audio_seq_len': 645,
            'max_duration': 30,
            'uncondition': False,
            'text_encoder_name': "google/flan-t5-large"
        }
        model = TangoFlux(config)

        print(f"Process {process_id}: Loading model from {args.model} on {device}")
        weights = load_file(args.model)
        model.load_state_dict(weights, strict=False)
        model = model.to(device)
        model.eval()

        # Initialize the VAE used to decode latents into waveforms.
        vae = AutoencoderOobleck.from_pretrained(
            "stabilityai/stable-audio-open-1.0", subfolder='vae'
        )
        vae = vae.to(device)
        vae.eval()

        outputs = []
        for item in tqdm(chunk, total=len(chunk), desc=f"GPU {gpu_id}"):
            text = item['captions']

            # Skip prompts whose first sample already exists on disk.
            if os.path.exists(os.path.join(output_dir, f"id_{item['id']}_sample1.wav")):
                print(f"Process {process_id}: id_{item['id']} already exists, skipping.")
                continue

            with torch.no_grad():
                latent = model.inference_flow(
                    text,
                    num_inference_steps=args.num_steps,
                    guidance_scale=args.guidance_scale,
                    duration=10,
                    num_samples_per_prompt=args.num_samples
                )
                # 220 is the latent length of a 10 s AudioCaps clip under this
                # VAE; modify it if you generate a different duration.
                latent = latent[:, :220, :]
                wave = vae.decode(latent.transpose(2, 1)).sample.cpu()

            for i in range(args.num_samples):
                filename = f"id_{item['id']}_sample{i+1}.wav"
                filepath = os.path.join(output_dir, filename)
                sf.write(filepath, wave[i].T, samplerate)
                outputs.append({
                    "id": item['id'],
                    "sample": i + 1,
                    "path": filepath,
                    "captions": text
                })

        return_dict[process_id] = outputs
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []
def split_into_chunks(data, num_chunks):
    """Split data into num_chunks approximately equal parts."""
    avg = len(data) // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * avg
        # The last chunk absorbs the remainder.
        end = (i + 1) * avg if i != num_chunks - 1 else len(data)
        chunks.append(data[start:end])
    return chunks
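# For example, split_into_chunks(list(range(10)), 3) returns
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]: the last chunk takes the remainder.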
def main():
    parser = argparse.ArgumentParser(description="Generate audio using multiple GPUs")
    parser.add_argument('--num_steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--model', type=str, required=True, help='Path to TangoFlux weights')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples per prompt')
    parser.add_argument('--output_dir', type=str, default='output', help='Directory to save outputs')
    parser.add_argument('--json_path', type=str, required=True, help='Path to input JSON file')
    parser.add_argument('--sample_size', type=int, default=20000, help='Number of prompts to sample for CRPO')
    parser.add_argument('--guidance_scale', type=float, default=4.5, help='Guidance scale used for generation')
    args = parser.parse_args()

    # Check GPU availability.
    num_gpus = torch.cuda.device_count()
    sample_size = args.sample_size

    # Load the prompt list from JSON.
    try:
        with open(args.json_path, 'r') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading JSON file {args.json_path}: {e}")
        return
    if not isinstance(data, list):
        print("Error: JSON data is not a list.")
        return

    if len(data) < sample_size:
        print(f"Warning: JSON data contains only {len(data)} items. Sampling all available data.")
        sampled = data
    else:
        sampled = random.sample(data, sample_size)

    # Shuffle and split the prompts into one chunk per available GPU.
    random.shuffle(sampled)
    chunks = split_into_chunks(sampled, num_gpus)

    # Prepare output directory.
    os.makedirs(args.output_dir, exist_ok=True)
    samplerate = 44100

    # Manager dict collects per-process outputs.
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []
    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=generate_audio_chunk,
            args=(
                args,
                chunks[i],
                i,  # GPU ID
                args.output_dir,
                samplerate,
                return_dict,
                i,  # Process ID
            )
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")

    # Aggregate results. Iterate over len(sampled) rather than sample_size so
    # the manifest stays correct when fewer prompts were available than requested.
    audio_info_list = [
        [
            {
                "path": f"{args.output_dir}/id_{sampled[j]['id']}_sample{i}.wav",
                "duration": sampled[j]["duration"],
                "captions": sampled[j]["captions"]
            }
            for i in range(1, args.num_samples + 1)
        ]
        for j in range(len(sampled))
    ]
    with open(f'{args.output_dir}/results.json', 'w') as f:
        json.dump(audio_info_list, f)
    print(f"All audio samples have been generated and saved to {args.output_dir}")
if __name__ == "__main__":
    # 'spawn' is required so each child process gets its own clean CUDA context.
    multiprocessing.set_start_method('spawn')
    main()
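# Example invocation (script name, paths, and filenames are placeholders):
#   python generate.py --model tangoflux.safetensors --json_path prompts.json \
#       --output_dir output --num_steps 50 --num_samples 5 --guidance_scale 4.5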