import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

# Make the project root importable so `src.*` modules resolve when the script
# is run directly (the script lives two directories below the project root).
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
project_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
sys.path.append(project_root)

from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_blip_config(model="base"):
    """Return the BLIP checkpoint URL and hyperparameters for the given model size."""
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    # Settings shared by both model sizes.
    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config


@torch.no_grad()
def main(args):
    dataset = ImageDataset(
        image_dir=args.image_dir,
        img_ext=args.img_ext,
        save_dir=args.save_dir,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    print("Creating model")
    config = get_blip_config(args.model_type)
    model = blip_embs(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )
    model = model.to(device)
    model.eval()

    # Encode each batch, project the [CLS] token, L2-normalize, and save one
    # embedding file per image/frame.
    for imgs, video_ids in tqdm(loader):
        imgs = imgs.to(device)
        img_embs = model.visual_encoder(imgs)
        img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        for img_feat, video_id in zip(img_feats, video_ids):
            torch.save(img_feat, args.save_dir / f"{video_id}.pth")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_dir", type=Path, required=True, help="Path to image directory"
    )
    parser.add_argument("--save_dir", type=Path)
    parser.add_argument("--img_ext", type=str, default="png")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--model_type", type=str, default="large", choices=["base", "large"]
    )
    args = parser.parse_args()

    subdirectories = [subdir for subdir in args.image_dir.iterdir() if subdir.is_dir()]

    if len(subdirectories) == 0:
        # Flat image directory: save embeddings next to it.
        args.save_dir = args.image_dir.parent / f"blip-embs-{args.model_type}"
        args.save_dir.mkdir(exist_ok=True)
        main(args)
    else:
        # One subdirectory per video: mirror the layout under blip-embs-<model_type>.
        for subdir in subdirectories:
            args.image_dir = subdir
            args.save_dir = (
                subdir.parent.parent / f"blip-embs-{args.model_type}" / subdir.name
            )
            args.save_dir.mkdir(exist_ok=True, parents=True)
            main(args)
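# Example invocation (illustrative only; the script path and data layout below are
# assumptions, not fixed by this repository):
#
#   python tools/embs/save_blip_embs_imgs.py --image_dir data/frames --model_type large
#
# If --image_dir contains per-video subfolders, embeddings are written to
# <image_dir parent>/blip-embs-<model_type>/<subfolder>/; otherwise they go to
# <image_dir parent>/blip-embs-<model_type>/.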