Spaces:
Sleeping
Sleeping
# -------------------------------------------------------- | |
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) | |
# Github source: https://github.com/microsoft/unilm/tree/master/beit | |
# Copyright (c) 2021 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# By Hangbo Bao | |
# Based on timm, DINO and DeiT code bases | |
# https://github.com/rwightman/pytorch-image-models/tree/master/timm | |
# https://github.com/facebookresearch/deit/ | |
# https://github.com/facebookresearch/dino | |
# --------------------------------------------------------
import argparse | |
import os | |
import torch | |
import random | |
from torchvision import datasets, transforms | |
from timm.data.constants import \ | |
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD | |
from transforms import RandomResizedCropAndInterpolationWithTwoPic, _pil_interp | |
from timm.data import create_transform, ImageDataset | |
from masking_generator import MaskingGenerator | |
from dataset_folder import ImageFolder | |
class DataAugmentationForBEiT(object):
    """Augmentation pipeline for BEiT masked-image-modeling pre-training.

    Calling an instance on an image returns a 3-tuple:
      * the normalized tensor produced by ``patch_transform`` (first crop),
      * the un-normalized tensor produced by ``visual_token_transform``
        (second crop),
      * the output of ``masked_position_generator()``.

    Both crops come from a single ``common_transform`` pass, so they share
    the same colour jitter, flip and crop parameters.
    """

    def __init__(self, args):
        # Pick the normalization statistics requested on the command line.
        if args.imagenet_default_mean_and_std:
            mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        else:
            mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

        # Original BEiT data augmentation. The final step yields two crops
        # of the same region: one at ``input_size``, one at ``second_input_size``.
        self.common_transform = transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(p=0.5),
            RandomResizedCropAndInterpolationWithTwoPic(
                size=args.input_size, second_size=args.second_input_size, scale=(args.min_crop_scale, 1.0),
                interpolation=args.train_interpolation, second_interpolation=args.second_interpolation,
            ),
        ])

        # Input branch: tensor conversion followed by mean/std normalization.
        self.patch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ])

        # Tokenizer branch: tensor conversion only, deliberately not normalized here.
        self.visual_token_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.masked_position_generator = MaskingGenerator(
            args.window_size, num_masking_patches=args.num_mask_patches,
            max_num_patches=args.max_mask_patches_per_block,
            min_num_patches=args.min_mask_patches_per_block,
        )

    def __call__(self, image):
        """Return ``(patch_tensor, visual_token_tensor, mask)`` for ``image``."""
        for_patches, for_visual_tokens = self.common_transform(image)
        return (
            self.patch_transform(for_patches),
            self.visual_token_transform(for_visual_tokens),
            self.masked_position_generator(),
        )

    def __repr__(self):
        # Use a dedicated local name instead of shadowing the builtin ``repr``.
        fields = (
            ("common_transform", self.common_transform),
            ("patch_transform", self.patch_transform),
            ("visual_tokens_transform", self.visual_token_transform),
            ("Masked position generator", self.masked_position_generator),
        )
        body = "".join(" %s = %s,\n" % (label, str(value)) for label, value in fields)
        return "(DataAugmentationForBEiT,\n" + body + ")"
def build_beit_pretraining_dataset(args):
    """Create the BEiT pre-training dataset.

    Builds a ``DataAugmentationForBEiT`` from ``args``, prints its description,
    and wraps ``args.data_path`` in an ``ImageFolder`` that applies it.
    """
    beit_augmentation = DataAugmentationForBEiT(args)
    print("Data Aug = %s" % str(beit_augmentation))
    pretraining_dataset = ImageFolder(args.data_path, transform=beit_augmentation)
    return pretraining_dataset
############################################### Dataset and Transforms for Tokenizer Training ######################################################### | |
def build_vqkd_dataset(is_train, args):
    """Build the dataset used to train the visual tokenizer.

    Args:
        is_train: if true, use random augmentation (optional colour jitter,
            random resized crop, horizontal flip); otherwise use a deterministic
            resize + center crop.
        args: namespace providing ``input_size``, ``color_jitter``,
            ``min_crop_scale``, ``train_interpolation``, ``data_set``,
            ``data_path`` and ``eval_data_path``.

    Returns:
        An ``ImageFolder`` over ``args.data_path`` (training), or over
        ``args.eval_data_path`` for evaluation. If ``eval_data_path`` is empty
        or None, evaluation falls back to ``args.data_path``.

    Raises:
        NotImplementedError: if ``args.data_set`` is not ``'image_folder'``.

    Side effects:
        When ``is_train`` is false, ``args.crop_pct`` is always overwritten.
    """
    t = []
    if is_train:
        if args.color_jitter > 0.:
            t.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
        t.append(transforms.RandomResizedCrop(
            args.input_size, scale=(args.min_crop_scale, 1.0),
            interpolation=_pil_interp(args.train_interpolation)))
        t.append(transforms.RandomHorizontalFlip(0.5))
    else:
        # NOTE(review): unlike build_transform, any pre-set crop_pct is ignored
        # and overwritten here; kept as-is to preserve existing behavior.
        args.crop_pct = 224 / 256 if args.input_size < 384 else 1.0
        size = int(args.input_size / args.crop_pct)
        # Resize then center-crop to keep the same ratio w.r.t. 224 images.
        # The training interpolation setting is reused for evaluation.
        t.append(transforms.Resize(size, interpolation=_pil_interp(args.train_interpolation)))
        t.append(transforms.CenterCrop(args.input_size))
    t.append(transforms.ToTensor())
    transform = transforms.Compose(t)
    print(f"{'Train' if is_train else 'Test'} Data Aug: {str(transform)}")

    if args.data_set != 'image_folder':
        raise NotImplementedError(f"Unsupported data_set for VQ-KD training: {args.data_set!r}")

    # Treat both '' and None as "no separate evaluation directory".
    if is_train or not args.eval_data_path:
        return ImageFolder(args.data_path, transform=transform)
    return ImageFolder(args.eval_data_path, transform=transform)
############################################### Dataset and Transforms for Fine-tuning #########################################################
def _print_transform(transform):
    """Print every step of ``transform``, or of each transform when given a tuple."""
    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")


def build_dataset(is_train, args):
    """Build the fine-tuning / evaluation dataset selected by ``args.data_set``.

    Supported values:
      * ``'CIFAR'``: torchvision CIFAR-100 under ``args.data_path`` (100 classes).
      * ``'IMNET'``: torchvision ``ImageFolder`` over ``<data_path>/train`` or
        ``<data_path>/val`` (1000 classes).
      * ``'image_folder'``: project ``ImageFolder`` over ``args.data_path``
        (train) or ``args.eval_data_path`` (eval), using
        ``args.image_folder_class_index_file``. The class count comes from
        ``args.nb_classes``.

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: for an unsupported ``args.data_set``.
        ValueError: if the discovered/implied class count does not match
            ``args.nb_classes``. These checks used to be ``assert`` statements,
            which are skipped under ``python -O``.
    """
    transform = build_transform(is_train, args)
    _print_transform(transform)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        index_file = args.image_folder_class_index_file
        dataset = ImageFolder(root, transform=transform, index_file=index_file)
        nb_classes = args.nb_classes
        found_classes = len(dataset.class_to_idx)
        if found_classes != nb_classes:
            raise ValueError(
                f"Found {found_classes} classes in {root!r}, "
                f"but args.nb_classes is {nb_classes}")
    else:
        raise NotImplementedError(f"Unsupported data_set: {args.data_set!r}")

    if nb_classes != args.nb_classes:
        raise ValueError(
            f"Dataset {args.data_set!r} has {nb_classes} classes, "
            f"but args.nb_classes is {args.nb_classes}")
    print("Number of the class = %d" % args.nb_classes)
    return dataset, nb_classes
def build_transform(is_train, args):
    """Build the image transform for fine-tuning (train) or evaluation.

    Training delegates to timm's ``create_transform``. When ``input_size`` is
    32 or smaller, its first step is swapped for a padded ``RandomCrop``.
    Evaluation does the following, then ``ToTensor`` + ``Normalize``:
      * for inputs larger than 32 px, resize to ``input_size / crop_pct``
        (bicubic, interpolation code 3) and center-crop to ``input_size``;
      * if ``args.crop_pct`` is None, it is first filled in (224/256 below
        384 px, else 1.0). This mutates ``args``.
    """
    needs_resize = args.input_size > 32
    if args.imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    if is_train:
        # This should always dispatch to transforms_imagenet_train.
        train_pipeline = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not needs_resize:
            # Small images: replace RandomResizedCropAndInterpolation with a padded RandomCrop.
            train_pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
        return train_pipeline

    steps = []
    if needs_resize:
        if args.crop_pct is None:
            args.crop_pct = 224 / 256 if args.input_size < 384 else 1.0
        resized_edge = int(args.input_size / args.crop_pct)
        # Resize then center-crop to keep the same ratio w.r.t. 224 images.
        steps += [
            transforms.Resize(resized_edge, interpolation=3),
            transforms.CenterCrop(args.input_size),
        ]
    steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(steps)