# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, DINO and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import argparse
import os
import torch
import random

from torchvision import datasets, transforms

from timm.data.constants import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
)
from transforms import RandomResizedCropAndInterpolationWithTwoPic, _pil_interp
from timm.data import create_transform, ImageDataset

from masking_generator import MaskingGenerator
from dataset_folder import ImageFolder


def _select_mean_std(args):
    """Return the ``(mean, std)`` normalization constants chosen by ``args``.

    The ImageNet "default" constants are used when
    ``args.imagenet_default_mean_and_std`` is truthy, otherwise the
    "inception" constants are used.
    """
    if args.imagenet_default_mean_and_std:
        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD


def _default_crop_pct(input_size):
    """Return the fallback center-crop ratio for ``input_size``.

    Inputs smaller than 384 use 224 / 256 (to maintain the same ratio
    w.r.t. 224 images); larger inputs are not cropped (ratio 1.0).
    """
    return 224 / 256 if input_size < 384 else 1.0


class DataAugmentationForBEiT(object):
    """Callable augmentation used for BEiT pre-training.

    Calling an instance on an image returns a 3-tuple:
    ``(patch_input, visual_token_input, mask)`` where the first two come from
    the two outputs of the shared ``common_transform`` and the third is a
    fresh result of the masked position generator.
    """

    def __init__(self, args):
        """Build the transform pipelines from the attributes of ``args``.

        Reads ``imagenet_default_mean_and_std``, ``input_size``,
        ``second_input_size``, ``min_crop_scale``, ``train_interpolation``,
        ``second_interpolation``, ``window_size``, ``num_mask_patches``,
        ``max_mask_patches_per_block`` and ``min_mask_patches_per_block``.
        """
        mean, std = _select_mean_std(args)

        # original BEiT data augmentation
        self.common_transform = transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(p=0.5),
            RandomResizedCropAndInterpolationWithTwoPic(
                size=args.input_size,
                second_size=args.second_input_size,
                scale=(args.min_crop_scale, 1.0),
                interpolation=args.train_interpolation,
                second_interpolation=args.second_interpolation,
            ),
        ])

        # Normalized tensor for the branch returned first by __call__.
        self.patch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ])

        # The second branch is only converted to a tensor (no normalization).
        self.visual_token_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.masked_position_generator = MaskingGenerator(
            args.window_size,
            num_masking_patches=args.num_mask_patches,
            max_num_patches=args.max_mask_patches_per_block,
            min_num_patches=args.min_mask_patches_per_block,
        )

    def __call__(self, image):
        """Augment ``image`` and return ``(patches, visual_tokens, mask)``."""
        for_patches, for_visual_tokens = self.common_transform(image)
        return (
            self.patch_transform(for_patches),
            self.visual_token_transform(for_visual_tokens),
            self.masked_position_generator(),
        )

    def __repr__(self):
        """Return a multi-line description of every sub-transform."""
        # Built from parts and joined so the builtin ``repr`` is not shadowed.
        parts = [
            "(DataAugmentationForBEiT,\n",
            " common_transform = %s,\n" % str(self.common_transform),
            " patch_transform = %s,\n" % str(self.patch_transform),
            " visual_tokens_transform = %s,\n" % str(self.visual_token_transform),
            " Masked position generator = %s,\n" % str(self.masked_position_generator),
            ")",
        ]
        return "".join(parts)


def build_beit_pretraining_dataset(args):
    """Return an ``ImageFolder`` over ``args.data_path`` with BEiT augmentation.

    The augmentation description is printed as a side effect.
    """
    transform = DataAugmentationForBEiT(args)
    print("Data Aug = %s" % str(transform))
    return ImageFolder(args.data_path, transform=transform)


############################################### Dataset and Transforms for Tokenizer Training #########################################################

def build_vqkd_dataset(is_train, args):
    """Build the dataset used for tokenizer (VQ-KD) training or evaluation.

    Args:
        is_train: when true, build the random training pipeline; otherwise a
            deterministic resize + center-crop pipeline.
        args: namespace providing ``input_size``, ``train_interpolation``,
            ``data_set``, ``data_path``, ``eval_data_path`` and, for training,
            ``color_jitter`` and ``min_crop_scale``.

    Returns:
        An ``ImageFolder`` with the built transform.

    Raises:
        NotImplementedError: if ``args.data_set`` is not ``'image_folder'``.

    NOTE: in the evaluation branch ``args.crop_pct`` is overwritten
    unconditionally (unlike ``build_transform``, which only fills it in when it
    is ``None``). That existing side effect is kept for compatibility.
    """
    if is_train:
        t = []
        if args.color_jitter > 0.:
            t.append(transforms.ColorJitter(
                args.color_jitter, args.color_jitter, args.color_jitter))
        t.append(transforms.RandomResizedCrop(
            args.input_size,
            scale=(args.min_crop_scale, 1.0),
            interpolation=_pil_interp(args.train_interpolation)))
        t.append(transforms.RandomHorizontalFlip(0.5))
        t.append(transforms.ToTensor())
    else:
        args.crop_pct = _default_crop_pct(args.input_size)
        size = int(args.input_size / args.crop_pct)
        t = [
            # to maintain same ratio w.r.t. 224 images
            transforms.Resize(
                size, interpolation=_pil_interp(args.train_interpolation)),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
        ]
    transform = transforms.Compose(t)

    print(f"{'Train' if is_train else 'Test'} Data Aug: {str(transform)}")

    if args.data_set != 'image_folder':
        raise NotImplementedError(
            f"Unsupported data_set for tokenizer training: {args.data_set!r}")

    # Evaluation falls back to the training path when no eval path is given.
    if is_train or args.eval_data_path == '':
        root = args.data_path
    else:
        root = args.eval_data_path
    return ImageFolder(root, transform=transform)


############################################### Dataset and Transforms for Ft #########################################################

def build_dataset(is_train, args):
    """Build the fine-tuning dataset selected by ``args.data_set``.

    Supported values are ``'CIFAR'``, ``'IMNET'`` and ``'image_folder'``.
    The transform is printed as a side effect.

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: for any other ``args.data_set``.
        AssertionError: if the dataset's class count disagrees with
            ``args.nb_classes``.
    """
    transform = build_transform(is_train, args)

    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(
            args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        index_file = args.image_folder_class_index_file
        dataset = ImageFolder(root, transform=transform, index_file=index_file)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes, (
            f"image folder has {len(dataset.class_to_idx)} classes, "
            f"expected {nb_classes}")
    else:
        raise NotImplementedError(f"Unsupported data_set: {args.data_set!r}")

    assert nb_classes == args.nb_classes, (
        f"data_set {args.data_set!r} has {nb_classes} classes, "
        f"but args.nb_classes is {args.nb_classes}")
    print("Number of the class = %d" % args.nb_classes)

    return dataset, nb_classes


def build_transform(is_train, args):
    """Build the fine-tuning transform pipeline.

    For training, the pipeline comes from timm's ``create_transform``; when
    ``args.input_size`` is 32 or less its first transform is replaced with a
    padded ``RandomCrop``. For evaluation, the pipeline is an optional
    resize + center-crop followed by tensor conversion and normalization.

    NOTE: when ``args.crop_pct`` is ``None`` in the evaluation branch it is
    filled in on ``args`` (existing side effect, kept for compatibility).
    """
    resize_im = args.input_size > 32
    mean, std = _select_mean_std(args)

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        if args.crop_pct is None:
            args.crop_pct = _default_crop_pct(args.input_size)
        size = int(args.input_size / args.crop_pct)
        t.append(
            # to maintain same ratio w.r.t. 224 images;
            # interpolation=3 is the integer code for PIL bicubic
            transforms.Resize(size, interpolation=3),
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)