WwYc commited on
Commit
8727b39
·
verified ·
1 Parent(s): 39c6f3a

Delete ViT_DeiT/VIT-EXPL.py

Browse files
Files changed (1) hide show
  1. ViT_DeiT/VIT-EXPL.py +0 -96
ViT_DeiT/VIT-EXPL.py DELETED
@@ -1,96 +0,0 @@
1
- import os
2
-
3
- from PIL import Image
4
- import torchvision.transforms as transforms
5
- import matplotlib.pyplot as plt
6
- import pylab
7
- import torch
8
- import numpy as np
9
- import cv2
10
- from samples.CLS2IDX import CLS2IDX
11
- from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
12
- from baselines.ViT.ViT_explanation_generator import LRP
13
-
14
- normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
15
- transform = transforms.Compose([
16
- transforms.Resize(256),
17
- transforms.CenterCrop(224),
18
- transforms.ToTensor(),
19
- normalize,
20
- ])
21
- use_thresholding = False
22
- def show_cam_on_image(img, mask):
23
- heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
24
- heatmap = np.float32(heatmap) / 255
25
- cam = heatmap + np.float32(img)
26
- cam = cam / np.max(cam)
27
- return cam
28
-
29
- # initialize ViT pretrained
30
- model = vit_LRP(pretrained=True).cuda()
31
- model.eval()
32
- attribution_generator = LRP(model)
33
-
34
- def generate_visualization(original_image, class_index=None):
35
- transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
36
- transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
37
- transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
38
- transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
39
- transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
40
-
41
- if use_thresholding:
42
- transformer_attribution = transformer_attribution * 255
43
- transformer_attribution = transformer_attribution.astype(np.uint8)
44
- ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
45
- transformer_attribution[transformer_attribution == 255] = 1
46
-
47
- image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
48
- image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
49
- vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
50
- vis = np.uint8(255 * vis)
51
- vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
52
- return vis
53
-
54
-
55
- def print_top_classes(predictions, **kwargs):
56
- # Print Top-5 predictions
57
- prob = torch.softmax(predictions, dim=1)
58
- class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
59
- max_str_len = 0
60
- class_names = []
61
- for cls_idx in class_indices:
62
- class_names.append(CLS2IDX[cls_idx])
63
- if len(CLS2IDX[cls_idx]) > max_str_len:
64
- max_str_len = len(CLS2IDX[cls_idx])
65
-
66
- print('Top 5 classes:')
67
- for cls_idx in class_indices:
68
- output_string = '\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])
69
- output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\t\t'
70
- output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
71
- print(output_string)
72
-
73
-
74
- image = Image.open('samples/dogcat2.png')
75
- dog_cat_image = transform(image)
76
-
77
- fig, axs = plt.subplots(1, 3)
78
- axs[0].imshow(image);
79
- axs[0].axis('off');
80
-
81
- output = model(dog_cat_image.unsqueeze(0).cuda())
82
- print_top_classes(output)
83
-
84
- # cat - the predicted class
85
- cat = generate_visualization(dog_cat_image)
86
-
87
- # dog
88
- # generate visualization for class 243: 'bull mastiff'
89
- dog = generate_visualization(dog_cat_image, class_index=243)
90
-
91
-
92
- axs[1].imshow(cat);
93
- axs[1].axis('off');
94
- axs[2].imshow(dog);
95
- axs[2].axis('off');
96
- pylab.show()