Spaces:
Sleeping
Sleeping
"""
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng
Hacked together by / Copyright 2020 Ross Wightman
Modified by Hangbo Bao, for generating the masked position for visual image transformer
"""
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# GitHub source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, DINO and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
# Copyright Zhun Zhong & Liang Zheng
#
# Hacked together by / Copyright 2020 Ross Wightman
#
# Modified by Hangbo Bao, for generating the masked position for visual image transformer
# --------------------------------------------------------
import random | |
import math | |
import numpy as np | |
class MaskingGenerator:
    """Generate block-wise binary masks over a 2-D grid of patches.

    Random rectangles are masked until no further rectangle can be placed or
    ``num_masking_patches`` cells are covered. Any shortfall is then topped up
    with individually sampled cells, so every returned mask contains exactly
    ``num_masking_patches`` ones.

    Args:
        input_size: Grid size. An int gives a square ``(n, n)`` grid; a tuple
            or list gives ``(height, width)``.
        num_masking_patches: Exact number of cells set to 1 in each mask.
        min_num_patches: Lower bound of the area sampled for a single block.
        max_num_patches: Upper bound on the cells one block may newly mask;
            defaults to ``num_masking_patches``.
        min_aspect: Minimum block aspect ratio (height / width).
        max_aspect: Maximum block aspect ratio; defaults to ``1 / min_aspect``.
    """

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        if isinstance(input_size, (tuple, list)):
            input_size = tuple(input_size)
        else:
            # A scalar size means a square grid.
            input_size = (input_size, ) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        # NOTE: ``or`` means a falsy max_aspect (None or 0) falls back to 1 / min_aspect.
        max_aspect = max_aspect or 1 / min_aspect
        # Stored in log space; _mask samples uniformly between these two bounds.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Mask one random rectangle of ``mask`` in place.

        Up to 10 attempts are made. Each samples an area and a log-uniform
        aspect ratio, and places the rectangle at a random position. The first
        rectangle that fits strictly inside the grid and would newly mask
        between 1 and ``max_mask_patches`` cells is applied.

        Returns:
            Number of cells changed from 0 to 1 (0 if every attempt failed).
        """
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)
                # Basic slicing yields a view, so writing to `region` updates `mask`.
                region = mask[top: top + h, left: left + w]
                # Cells already masked (overlap) do not count toward the budget.
                newly_masked = int(h * w - region.sum())
                if 0 < newly_masked <= max_mask_patches:
                    region[...] = 1
                    return newly_masked
        return 0

    def __call__(self):
        """Return a new ``int32`` mask of shape ``(height, width)``.

        The mask contains exactly ``num_masking_patches`` ones.

        Raises:
            ValueError: if ``num_masking_patches`` is negative or exceeds the
                number of cells in the grid.
        """
        if not 0 <= self.num_masking_patches <= self.num_patches:
            raise ValueError(
                f"num_masking_patches must be in [0, {self.num_patches}], "
                f"got {self.num_masking_patches}")

        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            # _mask never adds more than max_mask_patches cells, so mask_count
            # can only approach num_masking_patches from below, never overshoot.
            max_mask_patches = min(self.num_masking_patches - mask_count, self.max_num_patches)
            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        # Block masking can stall short of the target. Top up with randomly
        # chosen single cells to keep a fixed count of num_masking_patches.
        shortfall = self.num_masking_patches - mask_count
        if shortfall > 0:
            rows, cols = (mask == 0).nonzero()
            to_mask = np.random.choice(rows.shape[0], shortfall, replace=False)
            mask[rows[to_mask], cols[to_mask]] = 1
        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"
        return mask
if __name__ == '__main__':
    # Stress test: every generated mask must contain exactly num_masking_patches ones.
    generator = MaskingGenerator(input_size=14, num_masking_patches=118, min_num_patches=16)
    expected = generator.num_masking_patches
    for _ in range(10000000):
        mask = generator()
        if mask.sum() != expected:
            # Report the offending mask, then drop into the debugger to inspect it.
            print(mask)
            print(mask.sum())
            import pdb
            pdb.set_trace()