# label_clap.py — label a CRPO dataset with CLAP scores across multiple GPUs.
import os
import json
import argparse
import torch
import laion_clap
import numpy as np
import multiprocessing
from tqdm import tqdm
def parse_args():
    """Build and evaluate this script's command-line interface.

    Returns:
        argparse.Namespace with ``num_samples`` (int, default 5),
        ``json_path`` (str, required) and ``output_dir`` (str, required).
    """
    parser = argparse.ArgumentParser(
        description="Labelling clap score for crpo dataset"
    )
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of audio samples per prompt")
    parser.add_argument("--json_path", type=str, required=True,
                        help="Path to input JSON file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the final JSON with CLAP scores")
    return parser.parse_args()
#python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
@torch.no_grad()
def compute_clap(model, audio_files, text_data):
    """Score audio files against text prompts with CLAP.

    Embeds the audio files and the text through *model*, then returns the
    similarity matrix (audio embeddings times transposed text embeddings)
    with shape ``(len(audio_files), len(text_data))``. Gradients are
    disabled for the whole call.
    """
    audio_emb = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
    text_emb = model.get_text_embedding(text_data, use_tensor=True)
    return torch.matmul(audio_emb, text_emb.T)
def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Process a chunk of the data on a specific GPU.

    Loads the CLAP model on the designated device, then for each item in the
    chunk computes the CLAP scores and attaches them to the data in place.
    The (possibly mutated) chunk is published under ``return_dict[process_id]``
    so the parent process can collect it; on any setup failure an empty list
    is published instead so the parent never blocks on a missing key.

    Args:
        args: parsed CLI namespace; only ``args.num_samples`` is read here.
        chunk: list of items, where each item is a list of ``num_samples``
            sample dicts sharing one prompt (each dict has at least 'path'
            and 'captions' keys — assumed from usage below; TODO confirm).
        gpu_id: CUDA device index this worker is pinned to.
        return_dict: multiprocessing.Manager dict used as the result channel.
        process_id: key under which this worker stores its finished chunk.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")
        # Initialize the CLAP model on this GPU (each process owns its copy).
        model = laion_clap.CLAP_Module(enable_fusion=False)
        model.to(device)
        # load_ckpt() with no argument loads the library's default checkpoint.
        model.load_ckpt()
        model.eval()
        for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # Each item is assumed to be a list of samples.
            # Skip if already computed (checked on the first sample only —
            # assumes all samples of an item are labelled together).
            if 'clap_score' in item[0]:
                continue
            # Collect audio file paths and text data (using the first caption)
            audio_files = [item[i]['path'] for i in range(args.num_samples)]
            text_data = [item[0]['captions']]
            try:
                clap_scores = compute_clap(model, audio_files, text_data)
            except Exception as e:
                # Best-effort: log and move on; this item keeps no 'clap_score'.
                print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
                continue
            # Attach the computed score to each sample in the item,
            # rounded to 3 decimals (clap_scores is num_samples x 1).
            for k in range(args.num_samples):
                item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)
        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        # Fatal worker error (e.g. model load failure): publish an empty
        # result so the parent's aggregation loop still finds this key.
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []
def split_into_chunks(data, num_chunks):
    """
    Split *data* into ``num_chunks`` contiguous parts whose sizes differ by
    at most one, preserving order (concatenating the chunks reproduces
    *data* exactly).

    The previous floor-division scheme piled the entire remainder onto the
    last chunk (e.g. 11 items over 4 chunks gave sizes 2, 2, 2, 5), which
    load-balances badly when each chunk goes to one GPU. Here the remainder
    is spread one extra item at a time over the leading chunks.

    Args:
        data: sequence to split (a list slice-compatible object).
        num_chunks: number of parts; must be positive.

    Returns:
        List of ``num_chunks`` slices of *data* (some may be empty when
        ``len(data) < num_chunks``).

    Raises:
        ValueError: if ``num_chunks`` is not positive (previously this
        surfaced as an unhelpful ZeroDivisionError).
    """
    if num_chunks <= 0:
        raise ValueError("num_chunks must be a positive integer")
    base, extra = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `extra` chunks carry one additional item.
        end = start + base + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks
def main():
    """Label every item in the input JSON with CLAP scores using all visible
    GPUs, save the scored data, then derive a preference-pair training set
    (best vs. worst sample per prompt) as ``train.json``."""
    args = parse_args()
    # Load data from JSON and slice by start/end if provided
    with open(args.json_path, 'r') as f:
        data = json.load(f)
    # Check GPU availability and split data accordingly.
    # NOTE(review): if no GPU is visible, num_gpus == 0 and the split below
    # fails — this script presumes at least one CUDA device; confirm.
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)
    # Prepare output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Create a manager dict to collect results from all processes
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []
    # One worker per GPU; process_id doubles as the GPU index and the
    # return_dict key, so results can be re-assembled in original order.
    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=process_chunk,
            args=(args, chunks[i], i, return_dict, i)
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")
    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")
    # Aggregate all chunks back into a single list (keyed by process id,
    # so the original item order is preserved)
    combined_data = []
    for i in range(num_gpus):
        combined_data.extend(return_dict[i])
    # Save the combined results to a single JSON file
    output_file = f"{args.output_dir}/clap_scores.json"
    with open(output_file, 'w') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")
    # Pick the best- and worst-scoring sample of each item.
    # NOTE(review): if any item was skipped due to a per-item error in the
    # worker, its samples lack 'clap_score' and this raises KeyError — confirm
    # that failed items are acceptable to crash on here.
    max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
    min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
    crpo_dataset = []
    # Build chosen/reject preference pairs; caption and duration are taken
    # from the chosen sample (samples of one item share the prompt).
    for chosen,reject in zip(max_item,min_item):
        crpo_dataset.append({"captions": chosen['captions'],
                             "duration": chosen['duration'],
                             "chosen": chosen['path'],
                             "reject": reject['path']})
    with open(f"{args.output_dir}/train.json",'w') as f:
        json.dump(crpo_dataset,f)
if __name__ == '__main__':
    # 'spawn' gives each worker a fresh interpreter and CUDA context; the
    # default 'fork' start method is unsafe once CUDA has been initialized.
    multiprocessing.set_start_method('spawn')
    main()