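"""Two-pass caption-embedding extraction with SigLIP2.

Pass 1: every visible GPU encodes its shard of the captions in Pretrain.json
with the text tower and saves the raw embeddings under raw_embeds/.
Pass 2 (only when mean_shift is True): the global mean over all raw embeddings
is computed, subtracted from each embedding, and the shifted embeddings are
written to shifted_embeds/.
"""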
import os
import json
import pickle
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch.multiprocessing as mp

# Paths and settings
INPUT_JSON = "Pretrain.json"
mean_shift = True  # If True, run a second pass that subtracts the global mean of all embeddings
CKPT = "/root/autodl-tmp/model/siglip2"
BATCH_SIZE = 512
LOAD_LIMIT = None  # Limit number of items, or None for all

# Output directories
RAW_DIR = "raw_embeds"
SHIFTED_DIR = "shifted_embeds"

# Create output directories if they don't exist
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)

# 1. Load data
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]
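
# Expected input layout (inferred from the fields read below; only 'id' and
# 'caption' are actually used):
#   [
#     {"id": "sample_0001", "caption": "a photo of ..."},
#     ...
#   ]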

# 2. Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# 3. Split data among GPUs
num_gpus = torch.cuda.device_count()
assert num_gpus > 0, "No CUDA devices found; this script needs at least one GPU."
chunks = np.array_split(items, num_gpus)

# Function to compute raw embeddings (no shift)
def compute_raw_embeddings(device, data_chunk, gpu_id):
    device = torch.device(device)
    model = AutoModel.from_pretrained(CKPT).to(device).eval()
    results = []  # To store raw embeddings

    for i in tqdm(range(0, len(data_chunk), BATCH_SIZE), desc=f"Device {gpu_id} Raw Batches"):
        batch = data_chunk[i:i + BATCH_SIZE]
        ids = [it['id'] for it in batch]
        captions = [it.get('caption', '') for it in batch]

        inputs = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            embs = model.get_text_features(**inputs)
        embs_np = embs.cpu().numpy()

        for idx, item_id in enumerate(ids):
            results.append({'id': item_id, 'embed': embs_np[idx]})

    # Save raw embeddings
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(results)} raw embeddings to {raw_file}")

# Function to apply mean shift and save final embeddings
def apply_mean_shift_and_save(global_mean, gpu_id):
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    out_file = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")

    with open(raw_file, 'rb') as f:
        data = pickle.load(f)

    # Subtract global mean
    for item in data:
        item['embed'] = item['embed'] - global_mean

    # Save shifted embeddings
    with open(out_file, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(data)} shifted embeddings to {out_file}")

# Main entry
def main():
    # 1st pass: compute raw embeddings in parallel
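    # Note: torch.multiprocessing uses the platform default start method (fork on
    # Linux). If CUDA ends up initialized in the parent before the workers fork,
    # the children cannot use the GPU; calling mp.set_start_method("spawn",
    # force=True) at the top of main() is a common workaround in that case.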
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings, args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    if mean_shift:
        # Load all raw embeddings to compute global mean
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, 'rb') as f:
                data = pickle.load(f)
            all_embeds.extend([item['embed'] for item in data])

        all_embeds = np.stack(all_embeds, axis=0)
        global_mean = np.mean(all_embeds, axis=0)
        print("Computed global mean of shape", global_mean.shape)

        # 2nd pass: subtract mean and save shifted embeddings
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)

if __name__ == "__main__":
    main()
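
# Downstream usage (illustrative sketch, not executed here): each output pickle
# holds a list of {'id': ..., 'embed': np.ndarray} dicts, so a consumer could do:
#
#   import pickle, numpy as np
#   with open("shifted_embeds/embeds_device_0.pkl", "rb") as f:
#       data = pickle.load(f)
#   ids = [item["id"] for item in data]
#   embeds = np.stack([item["embed"] for item in data])  # shape: (N, embed_dim)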