import os
import json
import pickle
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch.multiprocessing as mp

# Paths and settings
INPUT_JSON = "Pretrain.json"
MEAN_SHIFT = True  # Enable full-pass mean shifting
CKPT = "/root/autodl-tmp/model/siglip2"
BATCH_SIZE = 512
LOAD_LIMIT = None  # Limit number of items, or None for all

# Output directories
RAW_DIR = "raw_embeds"
SHIFTED_DIR = "shifted_embeds"

# Create output directories if they don't exist
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)

# 1. Load data
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]

# 2. Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(CKPT)
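# Note: with the 'spawn' start method set in the main guard below, each worker
# re-imports this module, so the tokenizer (and the JSON load above) are
# re-created per process rather than shared with the parent.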
# 3. Split data among GPUs
num_gpus = torch.cuda.device_count()
chunks = np.array_split(items, num_gpus)
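# np.array_split keeps chunk sizes within one item of each other; it returns
# object arrays, which still support len() and slicing in the workers.
assert num_gpus > 0, "No CUDA devices visible; at least one GPU is required"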

# Function to compute raw embeddings (no shift)
def compute_raw_embeddings(device, data_chunk, gpu_id):
    """Embed one chunk of captions on one GPU and pickle the results."""
    device = torch.device(device)
    model = AutoModel.from_pretrained(CKPT).to(device).eval()
    results = []  # To store raw embeddings as {'id': ..., 'embed': ndarray}
    for i in tqdm(range(0, len(data_chunk), BATCH_SIZE), desc=f"Device {gpu_id} Raw Batches"):
        batch = data_chunk[i:i + BATCH_SIZE]
        ids = [it['id'] for it in batch]
        captions = [it.get('caption', '') for it in batch]
        # SigLIP text towers are trained on fixed-length inputs, hence
        # padding="max_length" rather than dynamic padding
        inputs = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            embs = model.get_text_features(**inputs)
        embs_np = embs.cpu().numpy()
        for idx, item_id in enumerate(ids):
            results.append({'id': item_id, 'embed': embs_np[idx]})
    # Save raw embeddings
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(results)} raw embeddings to {raw_file}")

# Function to apply mean shift and save final embeddings
def apply_mean_shift_and_save(global_mean, gpu_id):
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    out_file = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'rb') as f:
        data = pickle.load(f)
    # Subtract global mean
    for item in data:
        item['embed'] = item['embed'] - global_mean
    # Save shifted embeddings
    with open(out_file, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(data)} shifted embeddings to {out_file}")

# Main entry
def main():
    # 1st pass: compute raw embeddings in parallel, one process per GPU
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings, args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
    if MEAN_SHIFT:
        # Load all raw embeddings to compute global mean
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, 'rb') as f:
                data = pickle.load(f)
            all_embeds.extend([item['embed'] for item in data])
        all_embeds = np.stack(all_embeds, axis=0)
        global_mean = np.mean(all_embeds, axis=0)
        print("Computed global mean of shape", global_mean.shape)
        # 2nd pass: subtract mean and save shifted embeddings
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)

if __name__ == "__main__":
    # CUDA contexts do not survive fork(); spawn fresh interpreters instead
    mp.set_start_method("spawn", force=True)
    main()
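
# Note: main() stacks every raw embedding in RAM to take the mean. For corpora
# where that is too large, an equivalent streaming mean can be accumulated one
# pickle at a time (a sketch, not wired into the pipeline above):
def streaming_global_mean():
    total, count = None, 0
    for i in range(num_gpus):
        with open(os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl"), 'rb') as f:
            part = np.stack([item['embed'] for item in pickle.load(f)], axis=0)
        # Accumulate in float64 to limit precision loss across many batches
        s = part.sum(axis=0, dtype=np.float64)
        total = s if total is None else total + s
        count += part.shape[0]
    return (total / count).astype(np.float32)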