WwYc committed
Commit c35cecd · verified · 1 Parent(s): 49ef971

Update ViT_DeiT/baselines/ViT/ViT_explanation_generator.py

ViT_DeiT/baselines/ViT/ViT_explanation_generator.py CHANGED
@@ -1,28 +1,46 @@
 import argparse
-import torch
+
 import numpy as np
+import torch
 from numpy import *
 
+
 # compute rollout between attention layers
 def compute_rollout_attention(all_layer_matrices, start_layer=0):
     # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
     num_tokens = all_layer_matrices[0].shape[1]
     batch_size = all_layer_matrices[0].shape[0]
-    eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
-    all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
-    matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
-                    for i in range(len(all_layer_matrices))]
+    eye = (
+        torch.eye(num_tokens)
+        .expand(batch_size, num_tokens, num_tokens)
+        .to(all_layer_matrices[0].device)
+    )
+    all_layer_matrices = [
+        all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))
+    ]
+    matrices_aug = [
+        all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+        for i in range(len(all_layer_matrices))
+    ]
     joint_attention = matrices_aug[start_layer]
-    for i in range(start_layer+1, len(matrices_aug)):
+    for i in range(start_layer + 1, len(matrices_aug)):
         joint_attention = matrices_aug[i].bmm(joint_attention)
     return joint_attention
 
+
 class LRP:
     def __init__(self, model):
         self.model = model
         self.model.eval()
 
-    def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
+    def generate_LRP(
+        self,
+        input,
+        index=None,
+        method="transformer_attribution",
+        is_ablation=False,
+        start_layer=0,
+    ):
         output = self.model(input)
         kwargs = {"alpha": 1}
         if index == None:
@@ -32,14 +50,18 @@ class LRP:
         one_hot[0, index] = 1
         one_hot_vector = one_hot
         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
-        one_hot = torch.sum(one_hot.cuda() * output)
+        one_hot = torch.sum(one_hot * output)
 
         self.model.zero_grad()
         one_hot.backward(retain_graph=True)
 
-        return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
-                                  start_layer=start_layer, **kwargs)
-
+        return self.model.relprop(
+            torch.tensor(one_hot_vector).to(input.device),
+            method=method,
+            is_ablation=is_ablation,
+            start_layer=start_layer,
+            **kwargs
+        )
 
 
 class Baselines:
@@ -48,14 +70,14 @@ class Baselines:
         self.model.eval()
 
     def generate_cam_attn(self, input, index=None):
-        output = self.model(input.cuda(), register_hook=True)
+        output = self.model(input, register_hook=True)
         if index == None:
             index = np.argmax(output.cpu().data.numpy())
 
         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
         one_hot[0][index] = 1
         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
-        one_hot = torch.sum(one_hot.cuda() * output)
+        one_hot = torch.sum(one_hot * output)
 
         self.model.zero_grad()
         one_hot.backward(retain_graph=True)
@@ -79,5 +101,7 @@ class Baselines:
             attn_heads = blk.attn.get_attention_map()
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
             all_layer_attentions.append(avg_heads)
-        rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
-        return rollout[:,0, 1:]
+        rollout = compute_rollout_attention(
+            all_layer_attentions, start_layer=start_layer
+        )
+        return rollout[:, 0, 1:]
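Beyond the black-style reformatting, the visible substantive change in this commit is dropping the hard-coded .cuda() calls in LRP.generate_LRP and Baselines.generate_cam_attn, leaving device placement to the caller; the relevance target passed to relprop is still moved with .to(input.device). The sketch below is not part of the commit. It only illustrates, on synthetic attention maps, how compute_rollout_attention from this file is typically exercised; the import path and the ViT-B/16 token count (196 patches + 1 CLS token) are assumptions made for the example.

# Minimal usage sketch (not part of the commit): exercise compute_rollout_attention
# on synthetic, row-stochastic attention maps to show the expected shapes.
import torch

# Import path assumed to match the file's location in this repository.
from ViT_DeiT.baselines.ViT.ViT_explanation_generator import compute_rollout_attention

batch_size, num_tokens, num_layers = 2, 197, 12  # 196 patches + 1 CLS token (ViT-B/16)

# One head-averaged attention matrix per layer, shaped (batch, tokens, tokens),
# mimicking what blk.attn.get_attention_map() yields after averaging over heads.
attn_per_layer = [
    torch.softmax(torch.randn(batch_size, num_tokens, num_tokens), dim=-1)
    for _ in range(num_layers)
]

rollout = compute_rollout_attention(attn_per_layer, start_layer=0)
print(rollout.shape)        # torch.Size([2, 197, 197])

# The CLS row without the CLS column is the per-patch relevance map that the
# rollout baseline in Baselines returns as rollout[:, 0, 1:].
cls_relevance = rollout[:, 0, 1:]
print(cls_relevance.shape)  # torch.Size([2, 196])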