Spaces:
Sleeping
Sleeping
# -------------------------------------------------------- | |
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) | |
# Github source: https://github.com/microsoft/unilm/tree/master/beit | |
# Copyright (c) 2021 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# Based on DINO code bases | |
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py | |
# --------------------------------------------------------' | |
import os | |
import sys | |
import argparse | |
import cv2 | |
import random | |
import colorsys | |
import requests | |
from io import BytesIO | |
import skimage.io | |
from skimage.measure import find_contours | |
import matplotlib.pyplot as plt | |
from matplotlib.patches import Polygon | |
import torch | |
import torch.nn as nn | |
import torchvision | |
from torchvision import transforms as pth_transforms | |
import numpy as np | |
from PIL import Image | |
import utils | |
from timm.models import create_model | |
import modeling_pretrain | |
def apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into the first three channels of ``image`` in place.

    For every channel ``c`` in 0..2 the new value is
    ``image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255``.

    Args:
        image: array indexed as ``image[:, :, c]``; it is modified in place.
        mask: per-pixel weights broadcastable against ``image[:, :, c]``.
        color: sequence with at least three entries, presumably in [0, 1]
            since each entry is scaled by 255.
        alpha: overall opacity of the colour overlay.

    Returns:
        The same ``image`` object after modification.
    """
    # These two factors do not depend on the channel, so compute them once.
    weight = alpha * mask
    keep = 1 - weight
    for channel in range(3):
        tint = weight * color[channel] * 255
        image[:, :, channel] = image[:, :, channel] * keep + tint
    return image
def random_colors(N, bright=True):
    """
    Return ``N`` evenly spaced hues as RGB tuples, in shuffled order.

    Hue ``i / N`` is used for each ``i in range(N)`` at full saturation; the
    HSV value is 1.0 when ``bright`` is truthy and 0.7 otherwise. Each tuple
    is the float output of ``colorsys.hsv_to_rgb``. The list is shuffled in
    place with the global ``random`` module, so results follow its seed.
    """
    value = 0.7
    if bright:
        value = 1.0
    palette = [colorsys.hsv_to_rgb(hue_step / N, 1, value) for hue_step in range(N)]
    random.shuffle(palette)
    return palette
def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    """Overlay ``mask`` on ``image`` in a random colour and save the figure to ``fname``.

    Args:
        image: array whose first two dimensions are (height, width) and whose
            first three channels are blended (see ``apply_mask``); values are
            presumably 0-255 since the blended result is cast to uint8.
        mask: 2-D array of per-pixel weights with the same height/width as
            ``image``; contours are drawn at the 0.5 level.
        fname: destination path handed to ``Figure.savefig``.
        figsize: figure size in inches.
        blur: if True, smooth the mask with a 10x10 box filter (``cv2.blur``)
            before blending.
        contour: if True, outline the mask's 0.5 iso-contours with polygons.
        alpha: blending weight of the mask colour, forwarded to ``apply_mask``.
    """
    fig = plt.figure(figsize=figsize, frameon=False)
    try:
        # One axes spanning the whole figure, so the saved image has no border.
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)

        # Only a single instance is drawn, so one random colour is enough.
        color = random_colors(1)[0]

        # Limit the view to exactly the image extent (y axis flipped, origin top-left).
        height, width = image.shape[:2]
        ax.set_ylim(height, 0)
        ax.set_xlim(0, width)

        # astype returns a new array, so the caller's image is left untouched.
        masked_image = image.astype(np.uint32)
        if blur:
            mask = cv2.blur(mask, (10, 10))
        masked_image = apply_mask(masked_image, mask, color, alpha)

        if contour:
            # Pad by one pixel so masks that touch the image edge still yield closed contours.
            padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = mask
            for verts in find_contours(padded_mask, 0.5):
                # Undo the padding offset and flip (row, col) to (x, y) for matplotlib.
                verts = np.fliplr(verts) - 1
                ax.add_patch(Polygon(verts, facecolor="none", edgecolor=color))

        ax.imshow(masked_image.astype(np.uint8), aspect='auto')
        fig.savefig(fname)
    finally:
        # Release the figure; otherwise each call leaves another figure open in pyplot.
        plt.close(fig)
    print(f"{fname} saved.")
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Visualize Self-Attention maps')
    parser.add_argument('--model', default='beit_base_patch16_224_8k_vocab', type=str,
                        help='Architecture (support only ViT atm).')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
    parser.add_argument('--input_size', default=480, type=int, help='Input resolution of the model.')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str,
                        help="Path to pretrained weights to load.")
    parser.add_argument("--checkpoint_key", default="model", type=str,
                        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
    parser.add_argument('--output_dir', default='../visualization', help='Path where to save visualizations.')
    # argparse %-formats help strings, so a literal percent sign must be written "%%";
    # a bare "%" here makes `--help` raise a TypeError.
    parser.add_argument("--threshold", type=float, default=0.6,
                        help="We visualize masks obtained by thresholding the self-attention "
                             "maps to keep xx%% of the mass.")
    parser.add_argument('--selected_row', default=8, type=int,
                        help='Row (in patches) of the query patch whose attention is visualized.')
    parser.add_argument('--selected_col', default=8, type=int,
                        help='Column (in patches) of the query patch whose attention is visualized.')
    args = parser.parse_args()

    # ---- build the model (frozen, eval mode) ----
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = create_model(
        args.model,
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0,
        attn_drop_rate=0,
        drop_block_rate=None,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)

    # ---- load weights ----
    if os.path.isfile(args.pretrained_weights):
        # NOTE(review): torch.load unpickles the file; only pass checkpoints from a trusted source.
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        print("There is no reference weights available for this model => We use random weights.")

    # ---- open image ----
    if args.image_path is None:
        # user has not specified any image - we use our own image
        print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
        print("Since no image path have been provided, we take the first image in our paper.")
        # Bounded wait, and fail loudly on HTTP errors instead of handing an error page to PIL.
        response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png", timeout=60)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        img = img.convert('RGB')
    elif os.path.isfile(args.image_path):
        with open(args.image_path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
    else:
        print(f"Provided image path {args.image_path} is non valid.")
        sys.exit(1)

    input_size = args.input_size
    transform = pth_transforms.Compose([
        pth_transforms.Resize(input_size),
        pth_transforms.CenterCrop(input_size),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img)

    # make the image divisible by the patch size (dim 1 = height, dim 2 = width)
    h = img.shape[1] - img.shape[1] % args.patch_size
    w = img.shape[2] - img.shape[2] % args.patch_size
    img = img[:, :h, :w].unsqueeze(0)
    # Patch grid size: rows along the height axis, columns along the width axis.
    n_rows = img.shape[-2] // args.patch_size
    n_cols = img.shape[-1] // args.patch_size

    selected_row = args.selected_row
    selected_col = args.selected_col
    # An out-of-range column would otherwise silently wrap into the next row's patch.
    if not (0 <= selected_row < n_rows and 0 <= selected_col < n_cols):
        print(f"Selected patch ({selected_row}, {selected_col}) is outside the {n_rows}x{n_cols} patch grid.")
        sys.exit(1)

    attentions = model.get_last_selfattention(img.to(device))
    _, nh, _, _ = attentions.size()  # nh = number of attention heads
    # Row-major flattening: the row stride is the number of COLUMNS. The first
    # token (assumed to be [CLS]) is skipped both as query (+1) and as key (1:).
    selected_index = selected_row * n_cols + selected_col
    attentions = attentions[0, :, selected_index + 1, 1:]

    # we keep only a certain percentage of the mass
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - args.threshold)
    # Undo the sort for every head at once: out[h, j] = th_attn[h, argsort(idx)[h, j]].
    th_attn = torch.gather(th_attn, 1, torch.argsort(idx))
    th_attn = th_attn.reshape(nh, n_rows, n_cols).float()
    # interpolate back to pixel resolution
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
    attentions = attentions.reshape(nh, n_rows, n_cols)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    # ---- save attention heatmaps ----
    os.makedirs(args.output_dir, exist_ok=True)
    img_fname = os.path.join(args.output_dir, "img.png")
    torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), img_fname)
    for j in range(nh):
        fname = os.path.join(args.output_dir, f"attn-head{j}.png")
        plt.imsave(fname=fname, arr=attentions[j], format='png')
        print(f"{fname} saved.")

    # ---- highlight the query patch (4-pixel inset, 50% red tint) ----
    image = skimage.io.imread(img_fname)
    select_image = image.copy()
    inset = 4
    row0 = selected_row * args.patch_size
    col0 = selected_col * args.patch_size
    region = select_image[row0 + inset:row0 + args.patch_size - inset,
                          col0 + inset:col0 + args.patch_size - inset, :3]
    # Float result is truncated back into the image dtype on assignment.
    region[...] = region * 0.5 + np.array([1.0, 0.0, 0.0]) * 255.0 * 0.5
    fname = os.path.join(args.output_dir, "select.png")
    plt.imsave(fname=fname, arr=select_image, format='png')
    print(f"{fname} saved.")

    # ---- thresholded masks per head ----
    if args.threshold < 1.0:
        for j in range(nh):
            display_instances(image, th_attn[j],
                              fname=os.path.join(args.output_dir, f"mask_th{args.threshold}_head{j}.png"),
                              blur=False)