import os
import json
import pickle

import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# Paths and settings
INPUT_JSON = "Pretrain.json"
MEAN_SHIFT = True  # Enable the second, full-pass mean-shifting stage
CKPT = "/root/autodl-tmp/model/siglip2"
BATCH_SIZE = 512
LOAD_LIMIT = None  # Limit the number of items, or None for all

# Output directories (created up front if they don't exist)
RAW_DIR = "raw_embeds"
SHIFTED_DIR = "shifted_embeds"
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)


def compute_raw_embeddings(device_str, data_chunk, gpu_id):
    """First pass: encode captions on one GPU and dump raw embeddings."""
    device = torch.device(device_str)
    # Load the tokenizer and model inside the worker: under the 'spawn'
    # start method, each child process needs its own copies.
    tokenizer = AutoTokenizer.from_pretrained(CKPT)
    model = AutoModel.from_pretrained(CKPT).to(device).eval()

    results = []  # Raw embeddings for this shard
    for i in tqdm(range(0, len(data_chunk), BATCH_SIZE),
                  desc=f"Device {gpu_id} raw batches"):
        batch = data_chunk[i:i + BATCH_SIZE]
        ids = [it["id"] for it in batch]
        captions = [it.get("caption", "") for it in batch]
        inputs = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            embs = model.get_text_features(**inputs)
        embs_np = embs.cpu().numpy()
        for idx, item_id in enumerate(ids):
            results.append({"id": item_id, "embed": embs_np[idx]})

    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, "wb") as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(results)} raw embeddings to {raw_file}")


def apply_mean_shift_and_save(global_mean, gpu_id):
    """Second pass: subtract the global mean from one shard and re-save."""
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    out_file = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")
    with open(raw_file, "rb") as f:
        data = pickle.load(f)
    # Subtract the global mean from every embedding in the shard
    for item in data:
        item["embed"] = item["embed"] - global_mean
    with open(out_file, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(data)} shifted embeddings to {out_file}")


def main():
    # 1. Load data
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        items = json.load(f)
    if LOAD_LIMIT is not None:
        items = items[:LOAD_LIMIT]

    # 2. Split data among GPUs
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices available")
    chunks = np.array_split(items, num_gpus)

    # 3. First pass: compute raw embeddings in parallel, one process per GPU
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings,
                       args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    if MEAN_SHIFT:
        # 4. Load all raw embeddings and compute the global mean
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, "rb") as f:
                data = pickle.load(f)
            all_embeds.extend(item["embed"] for item in data)
        global_mean = np.mean(np.stack(all_embeds, axis=0), axis=0)
        print("Computed global mean of shape", global_mean.shape)

        # 5. Second pass: subtract the mean and save shifted embeddings
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)


if __name__ == "__main__":
    # CUDA contexts cannot be shared with forked children; 'spawn'
    # gives each GPU worker a clean process and CUDA context.
    mp.set_start_method("spawn", force=True)
    main()
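

# --- Usage sketch (illustrative only, not invoked by the pipeline above) ---
# A minimal example of how a downstream consumer might reload the shifted
# shards into a single (N, D) matrix. The file layout and the "id"/"embed"
# record keys mirror what this script writes; the function name and the
# sorted-glob shard ordering are assumptions, not part of the pipeline.
def load_shifted_embeddings(shifted_dir=SHIFTED_DIR):
    import glob
    records = []
    # Shards are named embeds_device_{gpu_id}.pkl; sorting gives a stable order.
    for path in sorted(glob.glob(os.path.join(shifted_dir, "embeds_device_*.pkl"))):
        with open(path, "rb") as f:
            records.extend(pickle.load(f))
    ids = [r["id"] for r in records]
    matrix = np.stack([r["embed"] for r in records], axis=0)
    return ids, matrix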