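"""Extract SigLIP2 text embeddings for the captions in INPUT_JSON on multiple GPUs.

Each GPU worker embeds one shard of the captions and pickles the raw vectors to
RAW_DIR. If `mean_shift` is enabled, the parent process then computes the global
mean embedding, subtracts it from every vector, and writes the centered copies
to SHIFTED_DIR.
"""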
import json
import os
import pickle

import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# Configuration
INPUT_JSON = "Pretrain.json"             # list of items with 'id' and 'caption' fields
mean_shift = True                        # subtract the global mean embedding after extraction
CKPT = "/root/autodl-tmp/model/siglip2"  # local SigLIP2 checkpoint directory
BATCH_SIZE = 512                         # captions per forward pass
LOAD_LIMIT = None                        # optionally cap the number of items (None = use all)

RAW_DIR = "raw_embeds"                   # per-GPU raw embeddings
SHIFTED_DIR = "shifted_embeds"           # per-GPU mean-centered embeddings

os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)

# Load the caption items once in the parent process.
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]

# The tokenizer is created once here; each worker loads its own model replica.
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# Split the items into one shard per available GPU.
num_gpus = torch.cuda.device_count()
assert num_gpus > 0, "This script requires at least one CUDA device."
chunks = np.array_split(items, num_gpus)


def compute_raw_embeddings(device, data_chunk, gpu_id):
    """Embed one shard of captions on the given device and pickle the raw vectors."""
    device = torch.device(device)
    model = AutoModel.from_pretrained(CKPT).to(device).eval()
    results = []

    for i in tqdm(range(0, len(data_chunk), BATCH_SIZE), desc=f"Device {gpu_id} Raw Batches"):
        batch = data_chunk[i:i + BATCH_SIZE]
        ids = [it['id'] for it in batch]
        captions = [it.get('caption', '') for it in batch]

        inputs = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            embs = model.get_text_features(**inputs)
        embs_np = embs.cpu().numpy()

        for idx, item_id in enumerate(ids):
            results.append({'id': item_id, 'embed': embs_np[idx]})

    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(raw_file, 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(results)} raw embeddings to {raw_file}")


def apply_mean_shift_and_save(global_mean, gpu_id):
    """Subtract the global mean from one device's raw embeddings and re-pickle them."""
    raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    out_file = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")

    with open(raw_file, 'rb') as f:
        data = pickle.load(f)

    for item in data:
        item['embed'] = item['embed'] - global_mean

    with open(out_file, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(data)} shifted embeddings to {out_file}")


def main():
    # One worker per GPU, each embedding its own shard.
    procs = []
    for i in range(num_gpus):
        p = mp.Process(target=compute_raw_embeddings, args=(f"cuda:{i}", chunks[i], i))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    if mean_shift:
        # Gather all raw embeddings to compute the global mean.
        all_embeds = []
        for i in range(num_gpus):
            raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
            with open(raw_file, 'rb') as f:
                data = pickle.load(f)
            all_embeds.extend([item['embed'] for item in data])

        all_embeds = np.stack(all_embeds, axis=0)
        global_mean = np.mean(all_embeds, axis=0)
        print("Computed global mean of shape", global_mean.shape)

        # Center each device's embeddings and write them to SHIFTED_DIR.
        for i in range(num_gpus):
            apply_mean_shift_and_save(global_mean, i)


if __name__ == "__main__":
    main()
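
# Illustrative only: how a downstream consumer might read the shifted embeddings
# this script writes (one pickle per GPU, each a list of {'id': ..., 'embed': np.ndarray}):
#
#   with open(os.path.join(SHIFTED_DIR, "embeds_device_0.pkl"), "rb") as f:
#       data = pickle.load(f)
#   id_to_embed = {d['id']: d['embed'] for d in data}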