# Spaces: Running on Zero
import os
import json
import argparse
import torch
import laion_clap
import numpy as np
import multiprocessing
from tqdm import tqdm
def parse_args():
    """Parse command-line options for the CLAP-score labelling run.

    Returns:
        argparse.Namespace with ``num_samples``, ``json_path`` and
        ``output_dir`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Labelling clap score for crpo dataset"
    )
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of audio samples per prompt")
    parser.add_argument("--json_path", type=str, required=True,
                        help="Path to input JSON file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the final JSON with CLAP scores")
    return parser.parse_args()
#python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
def compute_clap(model, audio_files, text_data):
    """Return the CLAP similarity matrix between audio files and texts.

    Embeds the audio paths and the text strings with the given CLAP model,
    then takes the pairwise dot products; result has one row per audio file
    and one column per text entry.
    """
    audio_embeddings = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
    text_embeddings = model.get_text_embedding(text_data, use_tensor=True)
    scores = audio_embeddings @ text_embeddings.T
    return scores
def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Score one chunk of the dataset on a single GPU.

    Loads a CLAP model onto ``cuda:<gpu_id>``, walks every item of the chunk
    (each item is a list of per-sample dicts for one prompt), computes the
    audio-text CLAP scores, and writes a rounded score onto each sample.
    The annotated chunk is stored in ``return_dict[process_id]``; if the
    worker fails fatally, an empty list is stored there instead.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Build the CLAP model on this worker's GPU.
        model = laion_clap.CLAP_Module(enable_fusion=False)
        model.to(device)
        model.load_ckpt()
        model.eval()

        for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # Items scored on a previous run already carry 'clap_score'.
            if 'clap_score' in item[0]:
                continue

            # Gather the per-sample audio paths and the shared caption
            # (only the first sample's caption is used for the whole item).
            audio_files = []
            for idx in range(args.num_samples):
                audio_files.append(item[idx]['path'])
            text_data = [item[0]['captions']]

            try:
                scores = compute_clap(model, audio_files, text_data)
            except Exception as exc:
                print(f"Error processing item index {j} on GPU {gpu_id}: {exc}")
                continue

            # One score per audio sample, rounded for compact JSON output.
            for idx in range(args.num_samples):
                item[idx]['clap_score'] = np.round(scores[idx].item(), 3)

        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as exc:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {exc}")
        return_dict[process_id] = []
def split_into_chunks(data, num_chunks):
    """
    Split *data* into *num_chunks* contiguous, near-equal parts.

    The first ``len(data) % num_chunks`` chunks receive one extra element,
    so chunk sizes never differ by more than one. (The previous floor-based
    split gave the whole remainder to the last chunk, which could be up to
    ``num_chunks - 1`` items larger than the others and unbalance per-GPU
    workloads.) Concatenating the returned chunks in order reproduces
    *data* exactly, which is what the caller relies on when reassembling.

    Raises:
        ValueError: if num_chunks is not a positive integer.
    """
    if num_chunks < 1:
        raise ValueError("num_chunks must be a positive integer")
    base, extra = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `extra` chunks absorb one remainder element each.
        end = start + base + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks
def main():
    """Label every sample in the input JSON with a CLAP score, one worker
    process per GPU, then derive a chosen/reject preference dataset.

    Reads ``args.json_path`` (a list of items, each a list of per-sample
    dicts), writes ``clap_scores.json`` (all items with scores attached) and
    ``train.json`` (best/worst sample pair per item) into ``args.output_dir``.
    """
    args = parse_args()

    with open(args.json_path, 'r') as f:
        data = json.load(f)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Previously this fell through to a ZeroDivisionError inside
        # split_into_chunks; fail with an actionable message instead.
        raise RuntimeError("No CUDA GPUs available; this script requires at least one GPU.")
    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)

    os.makedirs(args.output_dir, exist_ok=True)

    # Manager dict collects each worker's annotated chunk, keyed by process id.
    manager = multiprocessing.Manager()
    return_dict = manager.dict()

    processes = []
    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=process_chunk,
            args=(args, chunks[i], i, return_dict, i)
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")

    # Reassemble in chunk order so the output order matches the input file.
    # A worker that crashed stores an empty list; surface that data loss
    # instead of silently dropping the chunk.
    combined_data = []
    for i in range(num_gpus):
        if chunks[i] and not return_dict[i]:
            print(f"WARNING: worker {i} returned no results; its chunk was lost.")
        combined_data.extend(return_dict[i])

    output_file = os.path.join(args.output_dir, "clap_scores.json")
    with open(output_file, 'w') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")

    # For each item, pair the best- and worst-scoring samples (DPO-style).
    max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
    min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
    crpo_dataset = []
    for chosen, reject in zip(max_item, min_item):
        crpo_dataset.append({
            "captions": chosen['captions'],
            "duration": chosen['duration'],
            "chosen": chosen['path'],
            "reject": reject['path'],
        })
    with open(os.path.join(args.output_dir, "train.json"), 'w') as f:
        json.dump(crpo_dataset, f)
if __name__ == '__main__':
    # CUDA contexts cannot be re-initialized in forked children, so the
    # worker processes must be started with the 'spawn' method.
    multiprocessing.set_start_method('spawn')
    main()