# ssh-download / image_embed.py
# XiN0919's picture
# Upload folder using huggingface_hub
# cd77de9 verified
import os
import json
import pickle
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch.multiprocessing as mp
# Paths and settings
INPUT_JSON = "Pretrain.json"  # JSON file holding the sequence of records to embed
mean_shift = True  # Enable full-pass mean shifting
CKPT = "/root/autodl-tmp/model/siglip2"  # Checkpoint path passed to from_pretrained
BATCH_SIZE = 512  # Number of captions tokenized and embedded per forward pass
LOAD_LIMIT = None  # Limit number of items, or None for all

# Output directories
RAW_DIR = "raw_embeds"
SHIFTED_DIR = "shifted_embeds"

# Create output directories if they don't exist
os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)

# 1. Load data
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]

# 2. Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# 3. Split data among GPUs
num_gpus = torch.cuda.device_count()
if num_gpus < 1:
    # np.array_split rejects zero sections with an unhelpful message, so
    # fail here with an explanation of what is actually wrong.
    raise RuntimeError(
        "No CUDA devices are visible; at least one GPU is required to "
        "compute embeddings."
    )
# array_split (unlike split) accepts a length that does not divide evenly,
# so the per-GPU chunks may differ in size by at most one item.
chunks = np.array_split(items, num_gpus)
# Function to compute raw embeddings (no shift)
def compute_raw_embeddings(device, data_chunk, gpu_id):
    """Embed the captions of ``data_chunk`` on one device and pickle them.

    Args:
        device: Device spec accepted by ``torch.device`` (e.g. ``"cuda:0"``).
        data_chunk: Sliceable sequence of mappings. Each entry must provide
            an ``'id'`` key; a missing ``'caption'`` is embedded as ``''``.
        gpu_id: Index used in the progress-bar label and output file name.

    Side effects:
        Writes ``raw_embeds_device_{gpu_id}.pkl`` inside ``RAW_DIR``. The file
        holds a list of ``{'id': ..., 'embed': ...}`` dicts, one per entry,
        in input order.
    """
    target = torch.device(device)
    model = AutoModel.from_pretrained(CKPT).to(target)
    model.eval()

    records = []  # Accumulates one {'id', 'embed'} dict per input entry
    batch_starts = range(0, len(data_chunk), BATCH_SIZE)
    for start in tqdm(batch_starts, desc=f"Device {gpu_id} Raw Batches"):
        batch = data_chunk[start:start + BATCH_SIZE]
        batch_ids = [entry['id'] for entry in batch]
        texts = [entry.get('caption', '') for entry in batch]

        # Every caption is padded/truncated to exactly 64 tokens.
        encoded = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        encoded = encoded.to(target)

        with torch.no_grad():
            features = model.get_text_features(**encoded)
        vectors = features.cpu().numpy()

        # Pair each id with its row of the batch embedding matrix.
        records.extend(
            {'id': item_id, 'embed': vector}
            for item_id, vector in zip(batch_ids, vectors)
        )

    # Save raw embeddings
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'wb') as out:
        pickle.dump(records, out, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(records)} raw embeddings to {raw_file}")
# Function to apply mean shift and save final embeddings
def apply_mean_shift_and_save(global_mean, gpu_id):
    """Subtract ``global_mean`` from one device's raw embeddings and save them.

    Args:
        global_mean: Value subtracted from every record's ``'embed'`` entry.
        gpu_id: Index selecting which per-device file to read and write.

    Side effects:
        Reads ``raw_embeds_device_{gpu_id}.pkl`` from ``RAW_DIR`` and writes
        the shifted records to ``embeds_device_{gpu_id}.pkl`` in
        ``SHIFTED_DIR``.
    """
    src_path = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    # NOTE: pickle.load is only safe here because the file is produced by
    # this same pipeline; do not point it at untrusted data.
    with open(src_path, 'rb') as src:
        records = pickle.load(src)

    # Subtract global mean, keeping every other field of each record as-is.
    shifted = [
        {**record, 'embed': record['embed'] - global_mean}
        for record in records
    ]

    # Save shifted embeddings
    dst_path = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")
    with open(dst_path, 'wb') as dst:
        pickle.dump(shifted, dst, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(shifted)} shifted embeddings to {dst_path}")
# Main entry
def main():
    """Run the two-pass embedding pipeline.

    Pass 1 starts one worker process per GPU; each worker writes its raw
    embeddings to ``RAW_DIR``. When ``mean_shift`` is enabled, the mean
    embedding over all per-device files is computed and pass 2 writes the
    mean-subtracted embeddings to ``SHIFTED_DIR``.

    Raises:
        RuntimeError: If any worker process exits with a non-zero exit code.
    """
    # 1st pass: compute raw embeddings in parallel
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings, args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    # A crashed worker leaves its output file missing -- or, worse, stale
    # from an earlier run -- which would silently corrupt the global mean.
    # Stop here instead of continuing with incomplete data.
    failed = [(i, p.exitcode) for i, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        details = ", ".join(f"device {i} (exit code {code})" for i, code in failed)
        raise RuntimeError(f"Raw embedding worker(s) failed: {details}")

    if mean_shift:
        # Load all raw embeddings to compute global mean
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, 'rb') as f:
                data = pickle.load(f)
            all_embeds.extend([item['embed'] for item in data])
        all_embeds = np.stack(all_embeds, axis=0)
        global_mean = np.mean(all_embeds, axis=0)
        print("Computed global mean of shape", global_mean.shape)

        # 2nd pass: subtract mean and save shifted embeddings
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)


if __name__ == "__main__":
    main()