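"""
Multi-GPU inference for TangoFlux.

Samples prompts from a JSON file, spawns one worker process per visible GPU,
generates `num_samples` audio candidates per prompt, decodes the latents with
the Stable Audio Open VAE, and writes per-sample WAV files plus a results.json
index for CRPO preference-data construction.

Example invocation (a sketch; the script and weight filenames below are
illustrative, not from the original source):

    python tangoflux_generate.py --model tangoflux.safetensors \
        --json_path prompts.json --output_dir output \
        --num_samples 5 --num_steps 50 --guidance_scale 4.5
"""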
import os
import json
import time
import torch
import argparse
import multiprocessing
from tqdm import tqdm
from safetensors.torch import load_file
from diffusers import AutoencoderOobleck
import soundfile as sf
from model import TangoFlux
import random
def generate_audio_chunk(args, chunk, gpu_id, output_dir, samplerate, return_dict, process_id):
    """
    Generate audio for a chunk of text prompts on a specific GPU.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Initialize model
        config = {
            'num_layers': 6,
            'num_single_layers': 18,
            'in_channels': 64,
            'attention_head_dim': 128,
            'joint_attention_dim': 1024,
            'num_attention_heads': 8,
            'audio_seq_len': 645,
            'max_duration': 30,
            'uncondition': False,
            'text_encoder_name': "google/flan-t5-large"
        }
        model = TangoFlux(config)
        print(f"Process {process_id}: Loading model from {args.model} on {device}")
        w1 = load_file(args.model)
        model.load_state_dict(w1, strict=False)
        model = model.to(device)
        model.eval()

        # Initialize VAE
        vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0", subfolder='vae')
        vae = vae.to(device)
        vae.eval()
        outputs = []
        for item in tqdm(chunk, desc=f"GPU {gpu_id}"):
            text = item['captions']
            # Skip prompts whose first sample already exists on disk.
            if os.path.exists(os.path.join(output_dir, f"id_{item['id']}_sample1.wav")):
                print("Exists! Skipping!")
                continue
            with torch.no_grad():
                latent = model.inference_flow(
                    text,
                    num_inference_steps=args.num_steps,
                    guidance_scale=args.guidance_scale,
                    duration=10,
                    num_samples_per_prompt=args.num_samples
                )
                # 220 corresponds to the latent length of AudioCaps clips encoded
                # with this VAE; adjust it for a different target duration.
                latent = latent[:, :220, :]
                wave = vae.decode(latent.transpose(2, 1)).sample.cpu()
            for i in range(args.num_samples):
                filename = f"id_{item['id']}_sample{i+1}.wav"
                filepath = os.path.join(output_dir, filename)
                sf.write(filepath, wave[i].T, samplerate)
                outputs.append({
                    "id": item['id'],
                    "sample": i + 1,
                    "path": filepath,
                    "captions": text
                })
        return_dict[process_id] = outputs
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []

def split_into_chunks(data, num_chunks):
    """
    Splits data into num_chunks approximately equal parts.
    """
    avg = len(data) // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * avg
        # Ensure the last chunk takes the remainder
        end = (i + 1) * avg if i != num_chunks - 1 else len(data)
        chunks.append(data[start:end])
    return chunks
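
# Illustrative sanity check of the chunking (not in the original script):
# the final chunk absorbs the remainder, e.g.
#     split_into_chunks(list(range(10)), 3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]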

def main():
    parser = argparse.ArgumentParser(description="Generate audio using multiple GPUs")
    parser.add_argument('--num_steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--model', type=str, required=True, help='Path to TangoFlux weights')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples per prompt')
    parser.add_argument('--output_dir', type=str, default='output', help='Directory to save outputs')
    parser.add_argument('--json_path', type=str, required=True, help='Path to input JSON file')
    parser.add_argument('--sample_size', type=int, default=20000, help='Number of prompts to sample for CRPO')
    parser.add_argument('--guidance_scale', type=float, default=4.5, help='Guidance scale used for generation')
    args = parser.parse_args()

    # Check GPU availability
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("Error: no CUDA devices available.")
        return
    sample_size = args.sample_size

    # Load JSON data
    try:
        with open(args.json_path, 'r') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading JSON file {args.json_path}: {e}")
        return
    if not isinstance(data, list):
        print("Error: JSON data is not a list.")
        return

    if len(data) < sample_size:
        print(f"Warning: JSON data contains only {len(data)} items. Sampling all available data.")
        sampled = data
        sample_size = len(data)  # keep the aggregation index below in range
    else:
        sampled = random.sample(data, sample_size)

    # Split data into chunks based on available GPUs
    random.shuffle(sampled)
    chunks = split_into_chunks(sampled, num_gpus)

    # Prepare output directory
    os.makedirs(args.output_dir, exist_ok=True)
    samplerate = 44100

    # Manager for inter-process communication
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []
    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=generate_audio_chunk,
            args=(
                args,
                chunks[i],
                i,  # GPU ID
                args.output_dir,
                samplerate,
                return_dict,
                i,  # Process ID
            )
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")
    # Aggregate results: one sub-list of num_samples entries per sampled prompt.
    audio_info_list = [
        [
            {
                "path": f"{args.output_dir}/id_{sampled[j]['id']}_sample{i}.wav",
                "duration": sampled[j]["duration"],
                "captions": sampled[j]["captions"]
            }
            for i in range(1, args.num_samples + 1)
        ]
        for j in range(sample_size)
    ]
    with open(f'{args.output_dir}/results.json', 'w') as f:
        json.dump(audio_info_list, f)
    print(f"All audio samples have been generated and saved to {args.output_dir}")


if __name__ == "__main__":
    # CUDA contexts cannot be shared across forked workers; use 'spawn'.
    multiprocessing.set_start_method('spawn')
    main()
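
# Illustrative follow-up (not part of the original script): results.json stores,
# for each prompt, a list of `num_samples` candidate entries, e.g.
#
#     with open("output/results.json") as f:
#         groups = json.load(f)
#     print(groups[0][0]["path"])  # first candidate for the first prompt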