File size: 5,272 Bytes
3324de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import json
import argparse
import torch
import laion_clap
import numpy as np
import multiprocessing
from tqdm import tqdm

def parse_args():
    """Build and evaluate the command-line interface for this script.

    Returns:
        argparse.Namespace with `num_samples`, `json_path`, and `output_dir`.
    """
    parser = argparse.ArgumentParser(
        description="Labelling clap score for crpo dataset"
    )
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of audio samples per prompt")
    parser.add_argument("--json_path", type=str, required=True,
                        help="Path to input JSON file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the final JSON with CLAP scores")
    return parser.parse_args()

#python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
@torch.no_grad()
def compute_clap(model, audio_files, text_data):
    """Score every audio file against every text prompt with CLAP.

    Embeds the audio paths and the text strings with the given CLAP model,
    then returns the similarity matrix (audio x text) as their dot products.
    Gradients are disabled since this is inference only.
    """
    audio_emb = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
    text_emb = model.get_text_embedding(text_data, use_tensor=True)
    return torch.matmul(audio_emb, text_emb.T)

def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Process a chunk of the data on a specific GPU.

    Loads the CLAP model on the designated device, then for each item in the
    chunk computes the CLAP scores and attaches them to the data in place.
    The (possibly partially) annotated chunk is published through
    `return_dict[process_id]` so the parent process can reassemble results.

    Args:
        args: parsed CLI namespace; only `args.num_samples` is read here.
        chunk: list of items, where each item is a list of sample dicts.
            Each sample dict is assumed to carry 'path' and 'captions' keys
            (same caption shared across the item's samples) — TODO confirm
            against the producer of the input JSON.
        gpu_id: CUDA device index this worker is pinned to.
        return_dict: multiprocessing.Manager dict used as the result channel.
        process_id: key under which this worker stores its finished chunk.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Initialize the CLAP model on this GPU.
        # NOTE(review): checkpoint is loaded after .to(device); presumably
        # load_ckpt honors the module's current device — confirm with laion_clap.
        model = laion_clap.CLAP_Module(enable_fusion=False)
        model.to(device)
        model.load_ckpt()
        model.eval()

        for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # Each item is assumed to be a list of samples.
            # Skip if already computed (allows resuming a partial run).
            if 'clap_score' in item[0]:
                continue

            # Collect audio file paths and text data (using the first caption)
            audio_files = [item[i]['path'] for i in range(args.num_samples)]
            text_data = [item[0]['captions']]

            try:
                clap_scores = compute_clap(model, audio_files, text_data)
            except Exception as e:
                # Best-effort: a failed item is left unscored rather than
                # aborting the whole chunk; downstream must tolerate
                # items without 'clap_score'.
                print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
                continue

            # Attach the computed score to each sample in the item
            # (column 0: similarity against the single shared caption).
            for k in range(args.num_samples):
                item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)

        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        # Worker-level failure: publish an empty result so the parent's
        # aggregation loop does not block on a missing key.
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []

def split_into_chunks(data, num_chunks):
    """
    Split `data` into `num_chunks` contiguous parts of near-equal size.

    The remainder is spread one element at a time over the leading chunks,
    so chunk lengths differ by at most 1 (the previous floor-division scheme
    piled the entire remainder onto the last chunk). Concatenating the
    returned chunks reproduces `data` in its original order.

    Args:
        data: sequence to partition.
        num_chunks: number of parts; must be >= 1.

    Returns:
        List of `num_chunks` slices of `data` (some may be empty when
        len(data) < num_chunks).

    Raises:
        ValueError: if num_chunks < 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    base, extra = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `extra` chunks absorb one leftover element each.
        end = start + base + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks

def main():
    """Label every item in the input JSON with CLAP scores across all GPUs,
    then derive a chosen/reject preference dataset from the scores.

    Fan-out: one worker process per visible CUDA device, each scoring a
    contiguous chunk. Fan-in: chunks are concatenated in GPU order, saved to
    `clap_scores.json`, and best/worst samples per item are paired into
    `train.json`.

    Raises:
        RuntimeError: if no CUDA device is available (the previous version
            crashed later with ZeroDivisionError when splitting into 0 chunks).
    """
    args = parse_args()

    # Load data from JSON.
    with open(args.json_path, 'r') as f:
        data = json.load(f)

    # Check GPU availability and split data accordingly.
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices found; at least one GPU is required.")

    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)

    # Prepare output directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Manager dict collects each worker's annotated chunk, keyed by process id.
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    processes = []

    for i in range(num_gpus):
        p = multiprocessing.Process(
            target=process_chunk,
            args=(args, chunks[i], i, return_dict, i)
        )
        processes.append(p)
        p.start()
        print(f"Started process {i} on GPU {i}")

    for p in processes:
        p.join()
        print(f"Process {p.pid} has finished.")

    # Aggregate all chunks back into a single list (GPU order == input order).
    combined_data = []
    for i in range(num_gpus):
        combined_data.extend(return_dict[i])

    # Save the combined results to a single JSON file.
    output_file = os.path.join(args.output_dir, "clap_scores.json")
    with open(output_file, 'w') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")

    # Build preference pairs: highest-scoring sample is "chosen", lowest is
    # "reject". Samples left unscored by a worker-side error are skipped
    # (the previous version raised KeyError on them).
    crpo_dataset = []
    for item in combined_data:
        scored = [s for s in item if 'clap_score' in s]
        if not scored:
            continue
        chosen = max(scored, key=lambda s: s['clap_score'])
        reject = min(scored, key=lambda s: s['clap_score'])
        crpo_dataset.append({
            "captions": chosen['captions'],
            "duration": chosen['duration'],
            "chosen": chosen['path'],
            "reject": reject['path'],
        })

    with open(os.path.join(args.output_dir, "train.json"), 'w') as f:
        json.dump(crpo_dataset, f)


if __name__ == '__main__':
    # 'spawn' gives each worker a fresh interpreter — required because CUDA
    # contexts cannot be shared across fork()ed processes.
    multiprocessing.set_start_method('spawn')
    main()