WwYc commited on
Commit
5196e4f
·
verified ·
1 Parent(s): 3d27aee

Create visualization

Browse files
Files changed (1) hide show
  1. visualization +73 -0
visualization ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import matplotlib.pyplot as plt
6
+ import pylab
7
+ import torch
8
+ import numpy as np
9
+ import cv2
10
+ import sys
11
+ sys.path.append('ViT_DeiT')
12
+ from samples.CLS2IDX import CLS2IDX
13
+ from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
14
+ from baselines.ViT.ViT_explanation_generator import LRP
15
+
16
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
17
+ transform = transforms.Compose([
18
+ transforms.Resize(256),
19
+ transforms.CenterCrop(224),
20
+ transforms.ToTensor(),
21
+ normalize,
22
+ ])
23
+ use_thresholding = False
24
+ def show_cam_on_image(img, mask):
25
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
26
+ heatmap = np.float32(heatmap) / 255
27
+ cam = heatmap + np.float32(img)
28
+ cam = cam / np.max(cam)
29
+ return cam
30
+
31
+ # initialize ViT pretrained
32
+ model = vit_LRP(pretrained=True).cuda()
33
+ model.eval()
34
+ attribution_generator = LRP(model)
35
+
36
+ def generate_visualization(original_image, class_index=None):
37
+ transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
38
+ transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
39
+ transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
40
+ transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
41
+ transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
42
+
43
+ if use_thresholding:
44
+ transformer_attribution = transformer_attribution * 255
45
+ transformer_attribution = transformer_attribution.astype(np.uint8)
46
+ ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
47
+ transformer_attribution[transformer_attribution == 255] = 1
48
+
49
+ image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
50
+ image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
51
+ vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
52
+ vis = np.uint8(255 * vis)
53
+ vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
54
+ return vis
55
+
56
+
57
+ def print_top_classes(predictions, **kwargs):
58
+ # Print Top-5 predictions
59
+ prob = torch.softmax(predictions, dim=1)
60
+ class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
61
+ max_str_len = 0
62
+ class_names = []
63
+ for cls_idx in class_indices:
64
+ class_names.append(CLS2IDX[cls_idx])
65
+ if len(CLS2IDX[cls_idx]) > max_str_len:
66
+ max_str_len = len(CLS2IDX[cls_idx])
67
+
68
+ print('Top 5 classes:')
69
+ for cls_idx in class_indices:
70
+ output_string = '\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])
71
+ output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\t\t'
72
+ output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
73
+ print(output_string)