# NOTE(review): removed non-source residue ("Spaces: / Sleeping / Sleeping" —
# GitHub web-UI text captured by a copy-paste); it was not part of the program.
# --------------------------------------------------------
# BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
# Github source: https://github.com/microsoft/unilm/tree/master/beitv2
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Zhiliang Peng
# Based on BEiT, timm, DeiT and DINO code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
# Standard library.
import argparse
import os
import sys

# Third-party.
import torch
from PIL import Image
from timm.models import create_model
from torch import nn
from torchvision import transforms as pth_transforms

# Project-local.
import utils
import modeling_vqkd  # registers the vqkd_* model names with timm
def get_code(args):
    """Tokenize one image into VQ-KD codebook indices and print them.

    Args:
        args: argparse.Namespace with:
            model: timm model name of the VQ-KD tokenizer.
            pretrained_weights: path/URL of the pretrained checkpoint.
            img_path: path of the image to tokenize.
    """
    # ============ preparing data ... ============
    # interpolation=3 is PIL bicubic. Normalization is intentionally omitted:
    # it happens inside the vqkd model's own pre-processing (see comment below).
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        # pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # Normalize in pre-process of vqkd
    ])
    print(f"Image transforms: {transform}")
    # convert("RGB") makes grayscale/palette/RGBA inputs work too; a no-op for
    # plain RGB images, so existing behavior is unchanged.
    images = transform(Image.open(args.img_path).convert("RGB")).unsqueeze(0)

    # ============ building network ... ============
    # NOTE: ``as_tokenzer`` (sic) is the keyword actually expected by
    # modeling_vqkd — do not "fix" the spelling.
    model = create_model(
        args.model,
        pretrained=True,
        pretrained_weight=args.pretrained_weights,
        as_tokenzer=True,
    ).eval()

    # Pure inference: disable autograd to save memory/compute.
    with torch.no_grad():
        input_ids = model.get_codebook_indices(images)
    print(input_ids)
if __name__ == '__main__':
    # CLI entry point: parse arguments and tokenize the requested image.
    parser = argparse.ArgumentParser('Get code for VQ-KD')
    parser.add_argument('--model', default='vqkd_encoder_base_decoder_1x768x12_clip', type=str, help="model")
    parser.add_argument('--pretrained_weights',
                        default='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/vqkd_encoder_base_decoder_1x768x12_clip-d93179da.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D',
                        type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--img_path', default='demo/ILSVRC2012_val_00031649.JPEG', type=str, help="image path.")
    args = parser.parse_args()
    get_code(args)
# Expected output for the default image/checkpoint (reference run on GPU):
# tensor([[3812, 7466, 1913, 1913, 1903, 1913, 1903, 1913, 3812, 7820, 6337, 2189,
#         7466, 7466, 2492, 3743, 5268, 3481, 5268, 4987, 445, 8009, 3501, 5268,
#         7820, 7831, 4816, 2189, 7549, 7549, 5548, 4987, 445, 4198, 445, 5216,
#         4987, 5268, 3278, 5203, 6337, 1799, 847, 6454, 4527, 5302, 8009, 3743,
#         5216, 4678, 3743, 4858, 5203, 4816, 7831, 2189, 7549, 5386, 6628, 5004,
#         2779, 7131, 7131, 7131, 4928, 3743, 119, 445, 1903, 7466, 4527, 5386,
#         5398, 5704, 2104, 5398, 2779, 7258, 7989, 624, 7131, 1186, 5216, 7466,
#         8015, 5004, 452, 7243, 3145, 6690, 7017, 2104, 5398, 4198, 7989, 7131,
#         3717, 7466, 580, 5004, 5004, 6202, 6202, 6202, 1826, 7521, 1473, 5722,
#         2486, 5663, 4928, 3941, 580, 5548, 7983, 7983, 7983, 2104, 5004, 2063,
#         2637, 1822, 3100, 3100, 1405, 1637, 8187, 5433, 2779, 5398, 5004, 5004,
#         1107, 3469, 3469, 5302, 2590, 6381, 3100, 4194, 3717, 356, 7131, 7688,
#         5104, 3081, 3812, 3950, 1186, 7131, 7131, 3717, 4399, 1186, 2221, 6501,
#         7131, 5433, 3014, 3950, 3278, 2812, 7131, 1186, 7036, 6947, 7036, 4648,
#         2812, 7131, 3014, 5295, 7266, 5180, 4123, 3792, 4648, 8009, 4648, 4816,
#         1511, 7036, 375, 2221, 5813, 5698, 168, 7131, 3792, 5698, 5698, 2667,
#         5698, 4648, 4171, 6501]], device='cuda:0')