import os
import json
import argparse
import torch
import laion_clap
import numpy as np
import multiprocessing
from tqdm import tqdm


def parse_args():
    parser = argparse.ArgumentParser(
        description="Label CLAP scores for the CRPO dataset"
    )
    parser.add_argument(
        "--num_samples", type=int, default=5,
        help="Number of audio samples per prompt"
    )
    parser.add_argument(
        "--json_path", type=str, required=True,
        help="Path to the input JSON file"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory to save the final JSON with CLAP scores"
    )
    return parser.parse_args()


# Example usage:
# python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
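
# Expected input layout (inferred from how fields are accessed below; the exact
# schema is an assumption): a list with one entry per prompt, where each entry is
# a list of num_samples sample dicts, e.g.
# [
#   [
#     {"path": "<audio_file.wav>", "captions": "<prompt text>", "duration": <seconds>},
#     ...
#   ],
#   ...
# ]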


@torch.no_grad()
def compute_clap(model, audio_files, text_data):
    # Compute audio and text embeddings, then take their dot product (CLAP score)
    audio_embed = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
    text_embed = model.get_text_embedding(text_data, use_tensor=True)
    return audio_embed @ text_embed.T
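
# Note: with a single caption, text_embed has shape (1, D) and audio_embed has
# shape (len(audio_files), D), so the returned matrix is (len(audio_files), 1);
# clap_scores[k].item() below reads sample k's similarity to the prompt.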


def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Process a chunk of the data on a specific GPU.

    Loads the CLAP model on the designated device, then for each item in the chunk,
    computes the CLAP scores and attaches them to the data.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Initialize the CLAP model on this GPU
        model = laion_clap.CLAP_Module(enable_fusion=False)
        model.to(device)
        model.load_ckpt()
        model.eval()

        for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # Each item is assumed to be a list of samples.
            # Skip if already computed.
            if 'clap_score' in item[0]:
                continue

            # Collect audio file paths and text data (using the first caption)
            audio_files = [item[i]['path'] for i in range(args.num_samples)]
            text_data = [item[0]['captions']]

            try:
                clap_scores = compute_clap(model, audio_files, text_data)
            except Exception as e:
                print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
                continue

            # Attach the computed score to each sample in the item
            for k in range(args.num_samples):
                item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)

        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []


def split_into_chunks(data, num_chunks):
    """
    Splits data into num_chunks approximately equal parts.
    """
    avg = len(data) // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * avg
        # Ensure the last chunk takes the remainder of the data
        end = (i + 1) * avg if i != num_chunks - 1 else len(data)
        chunks.append(data[start:end])
    return chunks
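
# Example: split_into_chunks(list(range(10)), 3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
# (the last chunk absorbs the remainder).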


def main():
    args = parse_args()

    # Load data from JSON
    with open(args.json_path, 'r') as f:
        data = json.load(f)

    # Check GPU availability and split data accordingly
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)

    # Prepare output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Create a manager dict to collect results from all processes
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []

    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=process_chunk,
            args=(args, chunks[i], i, return_dict, i)
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")

    # Aggregate all chunks back into a single list
    combined_data = []
    for i in range(num_gpus):
        combined_data.extend(return_dict[i])

    # Save the combined results to a single JSON file
    output_file = f"{args.output_dir}/clap_scores.json"
    with open(output_file, 'w') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")
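
    # Build preference pairs for CRPO training: per prompt, the highest-scoring
    # sample becomes "chosen" and the lowest-scoring one "reject".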
    max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
    min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
    crpo_dataset = []
    for chosen, reject in zip(max_item, min_item):
        crpo_dataset.append({
            "captions": chosen['captions'],
            "duration": chosen['duration'],
            "chosen": chosen['path'],
            "reject": reject['path'],
        })
    with open(f"{args.output_dir}/train.json", 'w') as f:
        json.dump(crpo_dataset, f)


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')
    main()