# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on DINO code bases
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
# --------------------------------------------------------
import os
import sys
import argparse
import cv2
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
from timm.models import create_model
import modeling_pretrain  # registers the BEiT pre-training models with timm


def apply_mask(image, mask, color, alpha=0.5):
    # Blend `color` into the masked region of `image`, channel by channel.
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image


def random_colors(N, bright=True):
    """
    Generate random colors.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors

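# display_instances overlays a single binary mask on the image: it tints the
# masked region via apply_mask, optionally blurs the mask, and traces the mask
# outline with matplotlib Polygon patches (same helper as in DINO's
# visualize_attention.py).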
def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    fig = plt.figure(figsize=figsize, frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis('off')
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        if blur:
            _mask = cv2.blur(_mask, (10, 10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect='auto')
    fig.savefig(fname)
    print(f"{fname} saved.")
    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Visualize Self-Attention maps')
    parser.add_argument('--model', default='beit_base_patch16_224_8k_vocab', type=str,
                        help='Architecture (only ViT is supported at the moment).')
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=False)
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. Set 0 to disable layer scale.")
    parser.add_argument('--input_size', default=480, type=int, help='Input resolution of the model.')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str,
                        help="Path to pretrained weights to load.")
    parser.add_argument("--checkpoint_key", default="model", type=str,
                        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
    parser.add_argument('--output_dir', default='../visualization',
                        help='Path where to save visualizations.')
    parser.add_argument("--threshold", type=float, default=0.6,
                        help="""We visualize masks obtained by thresholding the self-attention maps
                        to keep xx percent of the mass.""")
    parser.add_argument('--selected_row', default=8, type=int)
    parser.add_argument('--selected_col', default=8, type=int)
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    model = create_model(
        args.model,
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0,
        attn_drop_rate=0,
        drop_block_rate=None,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)

    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        print("There are no reference weights available for this model => we use random weights.")

    # open image
    if args.image_path is None:
        # user has not specified any image - we use our own image
        print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
        print("Since no image path has been provided, we take the first image in our paper.")
        response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
        img = Image.open(BytesIO(response.content))
        img = img.convert('RGB')
    elif os.path.isfile(args.image_path):
        with open(args.image_path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
    else:
        print(f"Provided image path {args.image_path} is not valid.")
        sys.exit(1)

    input_size = args.input_size
    transform = pth_transforms.Compose([
        pth_transforms.Resize(input_size),
        pth_transforms.CenterCrop(input_size),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img)

    # make the image divisible by the patch size
    w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // args.patch_size
    h_featmap = img.shape[-1] // args.patch_size

    attentions = model.get_last_selfattention(img.to(device))
    bsz, nh, num_patches, _ = attentions.size()
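    # Map the selected (row, col) patch coordinate to a flat patch index.
    # Token 0 of the attention matrix is [CLS], so patch tokens start at
    # offset 1 (hence the `selected_index + 1` query row and the `1:` key
    # slice below). With the square center crop used above, w_featmap ==
    # h_featmap, so row * w_featmap + col addresses the row-major patch grid.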
    selected_row = args.selected_row
    selected_col = args.selected_col
    selected_index = selected_row * w_featmap + selected_col
    attentions = attentions[0, :, selected_index + 1, 1:]

    # we keep only a certain percentage of the mass
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - args.threshold)
    idx2 = torch.argsort(idx)
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
    # interpolate
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    # save attention heatmaps
    os.makedirs(args.output_dir, exist_ok=True)
    torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True),
                                 os.path.join(args.output_dir, "img.png"))
    for j in range(nh):
        fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
        plt.imsave(fname=fname, arr=attentions[j], format='png')
        print(f"{fname} saved.")

    image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
    select_image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))

    # highlight the selected patch with a translucent red square
    # (leaving a 4-pixel margin inside the patch)
    for _x in range(4, args.patch_size - 4):
        for _y in range(4, args.patch_size - 4):
            for c in range(3):
                x = _x + selected_row * args.patch_size
                y = _y + selected_col * args.patch_size
                select_image[x, y, c] = select_image[x, y, c] * 0.5 + [1.0, 0, 0][c] * 255.0 * 0.5
    fname = os.path.join(args.output_dir, "select.png")
    plt.imsave(fname=fname, arr=select_image, format='png')

    if args.threshold < 1.0:
        for j in range(nh):
            display_instances(image, th_attn[j],
                              fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) + ".png"),
                              blur=False)
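
# Example invocation (a sketch: the script name and paths below are
# placeholders, adjust them to your checkout and checkpoint):
#   python visualize_attention.py \
#       --model beit_base_patch16_224_8k_vocab \
#       --pretrained_weights /path/to/beit_base_checkpoint.pth \
#       --image_path /path/to/image.png \
#       --output_dir ../visualization \
#       --threshold 0.6 --selected_row 8 --selected_col 8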