WwYc committed on
Commit 1f1df39 · verified · 1 Parent(s): 5e6183c

Upload 69 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. BERT/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  2. BERT/.ipynb_checkpoints/Untitled1-checkpoint.ipynb +6 -0
  3. BERT/BERT-EXPL.py +66 -0
  4. BERT/BERT/config.json +20 -0
  5. BERT/BERT/eval_results_sst-2.txt +3 -0
  6. BERT/BERT/flax_model.msgpack +3 -0
  7. BERT/BERT/gitattributes +9 -0
  8. BERT/BERT/pytorch_model.bin +3 -0
  9. BERT/BERT/special_tokens_map.json +1 -0
  10. BERT/BERT/tokenizer_config.json +1 -0
  11. BERT/BERT/training_args.bin +3 -0
  12. BERT/BERT/vocab.txt +0 -0
  13. BERT/BERT_explainability/modules/BERT/BERT.py +671 -0
  14. BERT/BERT_explainability/modules/BERT/BERT_cls_lrp.py +202 -0
  15. BERT/BERT_explainability/modules/BERT/BERT_orig_lrp.py +671 -0
  16. BERT/BERT_explainability/modules/BERT/BertForSequenceClassification.py +215 -0
  17. BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py +156 -0
  18. BERT/BERT_explainability/modules/BERT/__pycache__/BERT.cpython-38.pyc +0 -0
  19. BERT/BERT_explainability/modules/BERT/__pycache__/BertForSequenceClassification.cpython-38.pyc +0 -0
  20. BERT/BERT_explainability/modules/BERT/__pycache__/ExplanationGenerator.cpython-311.pyc +0 -0
  21. BERT/BERT_explainability/modules/BERT/__pycache__/ExplanationGenerator.cpython-38.pyc +0 -0
  22. BERT/BERT_explainability/modules/__init__.py +0 -0
  23. BERT/BERT_explainability/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  24. BERT/BERT_explainability/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  25. BERT/BERT_explainability/modules/__pycache__/layers_ours.cpython-38.pyc +0 -0
  26. BERT/BERT_explainability/modules/layers_lrp.py +268 -0
  27. BERT/BERT_explainability/modules/layers_ours.py +292 -0
  28. BERT/BERT_params/boolq.json +26 -0
  29. BERT/BERT_params/boolq_baas.json +26 -0
  30. BERT/BERT_params/boolq_bert.json +32 -0
  31. BERT/BERT_params/boolq_soft.json +21 -0
  32. BERT/BERT_params/cose_bert.json +30 -0
  33. BERT/BERT_params/cose_multiclass.json +35 -0
  34. BERT/BERT_params/esnli_bert.json +28 -0
  35. BERT/BERT_params/evidence_inference.json +26 -0
  36. BERT/BERT_params/evidence_inference_bert.json +33 -0
  37. BERT/BERT_params/evidence_inference_soft.json +22 -0
  38. BERT/BERT_params/fever.json +26 -0
  39. BERT/BERT_params/fever_baas.json +25 -0
  40. BERT/BERT_params/fever_bert.json +32 -0
  41. BERT/BERT_params/fever_soft.json +21 -0
  42. BERT/BERT_params/movies.json +26 -0
  43. BERT/BERT_params/movies_baas.json +26 -0
  44. BERT/BERT_params/movies_bert.json +32 -0
  45. BERT/BERT_params/movies_soft.json +21 -0
  46. BERT/BERT_params/multirc.json +26 -0
  47. BERT/BERT_params/multirc_baas.json +26 -0
  48. BERT/BERT_params/multirc_bert.json +32 -0
  49. BERT/BERT_params/multirc_soft.json +21 -0
  50. BERT/BERT_rationale_benchmark/__init__.py +0 -0
BERT/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+  "cells": [],
+  "metadata": {},
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
BERT/.ipynb_checkpoints/Untitled1-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+  "cells": [],
+  "metadata": {},
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
BERT/BERT-EXPL.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ from transformers import BertTokenizer
+ from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
+ from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
+ from transformers import BertTokenizer
+ from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
+ from transformers import AutoTokenizer
+ from captum.attr import visualization
+ import spacy
+ import torch
+ from IPython.display import Image, HTML, display
+
+ from sequenceoutput.modeling_output import SequenceClassifierOutput
+
+ model = BertForSequenceClassification.from_pretrained("./BERT").to("cuda")
+ model.eval()
+ tokenizer = AutoTokenizer.from_pretrained("./BERT")
+ # initialize the explanations generator
+ explanations = Generator(model)
+
+ classifications = ["NEGATIVE", "POSITIVE"]
+
+ # encode a sentence
+ text_batch = ["I hate that I love you."]
+ encoding = tokenizer(text_batch, return_tensors='pt')
+ input_ids = encoding['input_ids'].to("cuda")
+ attention_mask = encoding['attention_mask'].to("cuda")
+
+ # true class is positive - 1
+ true_class = 1
+
+ # generate an explanation for the input
+ target_class = 0
+ expl = \
+     explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]
+ # normalize scores
+ expl = (expl - expl.min()) / (expl.max() - expl.min())
+
+ # get the model classification
+ output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
+ classification = output.argmax(dim=-1).item()
+ # get class name
+ class_name = classifications[target_class]
+ # if the classification is negative, higher explanation scores are more negative
+ # flip for visualization
+ if class_name == "NEGATIVE":
+     expl *= (-1)
+ token_importance = {}
+ tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
+ for i in range(len(tokens)):
+     token_importance[tokens[i]] = expl[i].item()
+ vis_data_records = [visualization.VisualizationDataRecord(
+     expl,
+     output[0][classification],
+     classification,
+     true_class,
+     true_class,
+     1,
+     tokens,
+     1)]
+
+ html1 = visualization.visualize_text(vis_data_records)
+ # print(token_importance, html1)
+ # with open('bert-xai.html', 'w+') as f:
+ #     f.write(str(html1))
+ # return token_importance, html1
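A possible continuation of the script above (an illustrative sketch, not part of the upload): once `expl`, `tokens`, and `html1` have been computed, the per-token relevances can be inspected directly and the Captum rendering persisted, which is what the commented-out lines at the end of the file suggest.

# Rank tokens by normalized LRP relevance; expl and tokens come from BERT-EXPL.py above.
ranked = sorted(zip(tokens, expl.tolist()), key=lambda pair: pair[1], reverse=True)
for token, score in ranked:
    print(f"{token:>12}  {score:.3f}")

# Recent Captum versions return an IPython HTML object from visualize_text;
# if so, its markup can be written out, as the commented-out lines intend.
if hasattr(html1, "data"):
    with open("bert-xai.html", "w") as f:
        f.write(html1.data)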
BERT/BERT/config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "architectures": [
+     "BertForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "finetuning_task": "sst-2",
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30522
+ }
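For reference, this is a standard bert-base-sized configuration fine-tuned on SST-2 (binary sentiment). A minimal sketch of inspecting it with the stock transformers API, assuming the checkpoint directory is ./BERT as in BERT-EXPL.py above:

from transformers import BertConfig, BertForSequenceClassification

# Load the configuration shipped with the checkpoint and confirm its key sizes.
config = BertConfig.from_pretrained("./BERT")
assert config.hidden_size == 768 and config.num_hidden_layers == 12

# The same directory also loads with the stock class; BERT-EXPL.py instead uses the
# LRP-enabled BertForSequenceClassification from the BERT_explainability package.
model = BertForSequenceClassification.from_pretrained("./BERT")
print(model.config.finetuning_task)  # "sst-2"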
BERT/BERT/eval_results_sst-2.txt ADDED
@@ -0,0 +1,3 @@
+ eval_loss = 0.2785584788237299
+ eval_acc = 0.9243119266055045
+ epoch = 3.0
BERT/BERT/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57ebdee44ea63b8f3a2a53011dabbd37a7bec8da5d38834beb9751075bb8b821
+ size 437942328
BERT/BERT/gitattributes ADDED
@@ -0,0 +1,9 @@
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
BERT/BERT/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5f7d1b5618ba58907379af830ee895c8800e3b381286b13d07d90aaf204dc40
+ size 437985387
BERT/BERT/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
BERT/BERT/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": true, "model_max_length": 512}
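Together with vocab.txt and special_tokens_map.json above, this is enough for AutoTokenizer to rebuild the lowercasing WordPiece tokenizer that BERT-EXPL.py relies on. A minimal sketch, assuming the files live in ./BERT:

from transformers import AutoTokenizer

# Rebuild the tokenizer from the uploaded files (vocab.txt, tokenizer_config.json,
# special_tokens_map.json); do_lower_case=true means input text is lowercased.
tokenizer = AutoTokenizer.from_pretrained("./BERT")
enc = tokenizer(["I hate that I love you."], return_tensors="pt")
print(tokenizer.convert_ids_to_tokens(enc["input_ids"][0]))
# e.g. ['[CLS]', 'i', 'hate', 'that', 'i', 'love', 'you', '.', '[SEP]']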
BERT/BERT/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da4b38103a827982f36030842c04dcc7f34bb64cb2f56fa45cc69860836ca5d1
+ size 1053
BERT/BERT/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
BERT/BERT_explainability/modules/BERT/BERT.py ADDED
@@ -0,0 +1,671 @@
1
+ from __future__ import absolute_import
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from transformers import BertConfig
8
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
9
+ from BERT_explainability.modules.layers_ours import *
10
+ from transformers import (
11
+ BertPreTrainedModel,
12
+ PreTrainedModel,
13
+ )
14
+
15
+ ACT2FN = {
16
+ "relu": ReLU,
17
+ "tanh": Tanh,
18
+ "gelu": GELU,
19
+ }
20
+
21
+
22
+ def get_activation(activation_string):
23
+ if activation_string in ACT2FN:
24
+ return ACT2FN[activation_string]
25
+ else:
26
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
27
+
28
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
29
+ # adding residual consideration
30
+ num_tokens = all_layer_matrices[0].shape[1]
31
+ batch_size = all_layer_matrices[0].shape[0]
32
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
33
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
34
+ all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
35
+ for i in range(len(all_layer_matrices))]
36
+ joint_attention = all_layer_matrices[start_layer]
37
+ for i in range(start_layer+1, len(all_layer_matrices)):
38
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
39
+ return joint_attention
40
+
41
+ class BertEmbeddings(nn.Module):
42
+ """Construct the embeddings from word, position and token_type embeddings."""
43
+
44
+ def __init__(self, config):
45
+ super().__init__()
46
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
47
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
48
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
49
+
50
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
51
+ # any TensorFlow checkpoint file
52
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
53
+ self.dropout = Dropout(config.hidden_dropout_prob)
54
+
55
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
56
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
57
+
58
+ self.add1 = Add()
59
+ self.add2 = Add()
60
+
61
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
62
+ if input_ids is not None:
63
+ input_shape = input_ids.size()
64
+ else:
65
+ input_shape = inputs_embeds.size()[:-1]
66
+
67
+ seq_length = input_shape[1]
68
+
69
+ if position_ids is None:
70
+ position_ids = self.position_ids[:, :seq_length]
71
+
72
+ if token_type_ids is None:
73
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
74
+
75
+ if inputs_embeds is None:
76
+ inputs_embeds = self.word_embeddings(input_ids)
77
+ position_embeddings = self.position_embeddings(position_ids)
78
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
79
+
80
+ # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
81
+ embeddings = self.add1([token_type_embeddings, position_embeddings])
82
+ embeddings = self.add2([embeddings, inputs_embeds])
83
+ embeddings = self.LayerNorm(embeddings)
84
+ embeddings = self.dropout(embeddings)
85
+ return embeddings
86
+
87
+ def relprop(self, cam, **kwargs):
88
+ cam = self.dropout.relprop(cam, **kwargs)
89
+ cam = self.LayerNorm.relprop(cam, **kwargs)
90
+
91
+ # [inputs_embeds, position_embeddings, token_type_embeddings]
92
+ (cam) = self.add2.relprop(cam, **kwargs)
93
+
94
+ return cam
95
+
96
+ class BertEncoder(nn.Module):
97
+ def __init__(self, config):
98
+ super().__init__()
99
+ self.config = config
100
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
101
+
102
+ def forward(
103
+ self,
104
+ hidden_states,
105
+ attention_mask=None,
106
+ head_mask=None,
107
+ encoder_hidden_states=None,
108
+ encoder_attention_mask=None,
109
+ output_attentions=False,
110
+ output_hidden_states=False,
111
+ return_dict=False,
112
+ ):
113
+ all_hidden_states = () if output_hidden_states else None
114
+ all_attentions = () if output_attentions else None
115
+ for i, layer_module in enumerate(self.layer):
116
+ if output_hidden_states:
117
+ all_hidden_states = all_hidden_states + (hidden_states,)
118
+
119
+ layer_head_mask = head_mask[i] if head_mask is not None else None
120
+
121
+ if getattr(self.config, "gradient_checkpointing", False):
122
+
123
+ def create_custom_forward(module):
124
+ def custom_forward(*inputs):
125
+ return module(*inputs, output_attentions)
126
+
127
+ return custom_forward
128
+
129
+ layer_outputs = torch.utils.checkpoint.checkpoint(
130
+ create_custom_forward(layer_module),
131
+ hidden_states,
132
+ attention_mask,
133
+ layer_head_mask,
134
+ )
135
+ else:
136
+ layer_outputs = layer_module(
137
+ hidden_states,
138
+ attention_mask,
139
+ layer_head_mask,
140
+ output_attentions,
141
+ )
142
+ hidden_states = layer_outputs[0]
143
+ if output_attentions:
144
+ all_attentions = all_attentions + (layer_outputs[1],)
145
+
146
+ if output_hidden_states:
147
+ all_hidden_states = all_hidden_states + (hidden_states,)
148
+
149
+ if not return_dict:
150
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
151
+ return BaseModelOutput(
152
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
153
+ )
154
+
155
+ def relprop(self, cam, **kwargs):
156
+ # assuming output_hidden_states is False
157
+ for layer_module in reversed(self.layer):
158
+ cam = layer_module.relprop(cam, **kwargs)
159
+ return cam
160
+
161
+ # not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
162
+ class BertPooler(nn.Module):
163
+ def __init__(self, config):
164
+ super().__init__()
165
+ self.dense = Linear(config.hidden_size, config.hidden_size)
166
+ self.activation = Tanh()
167
+ self.pool = IndexSelect()
168
+
169
+ def forward(self, hidden_states):
170
+ # We "pool" the model by simply taking the hidden state corresponding
171
+ # to the first token.
172
+ self._seq_size = hidden_states.shape[1]
173
+
174
+ # first_token_tensor = hidden_states[:, 0]
175
+ first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
176
+ first_token_tensor = first_token_tensor.squeeze(1)
177
+ pooled_output = self.dense(first_token_tensor)
178
+ pooled_output = self.activation(pooled_output)
179
+ return pooled_output
180
+
181
+ def relprop(self, cam, **kwargs):
182
+ cam = self.activation.relprop(cam, **kwargs)
183
+ #print(cam.sum())
184
+ cam = self.dense.relprop(cam, **kwargs)
185
+ #print(cam.sum())
186
+ cam = cam.unsqueeze(1)
187
+ cam = self.pool.relprop(cam, **kwargs)
188
+ #print(cam.sum())
189
+
190
+ return cam
191
+
192
+ class BertAttention(nn.Module):
193
+ def __init__(self, config):
194
+ super().__init__()
195
+ self.self = BertSelfAttention(config)
196
+ self.output = BertSelfOutput(config)
197
+ self.pruned_heads = set()
198
+ self.clone = Clone()
199
+
200
+ def prune_heads(self, heads):
201
+ if len(heads) == 0:
202
+ return
203
+ heads, index = find_pruneable_heads_and_indices(
204
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
205
+ )
206
+
207
+ # Prune linear layers
208
+ self.self.query = prune_linear_layer(self.self.query, index)
209
+ self.self.key = prune_linear_layer(self.self.key, index)
210
+ self.self.value = prune_linear_layer(self.self.value, index)
211
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
212
+
213
+ # Update hyper params and store pruned heads
214
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
215
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
216
+ self.pruned_heads = self.pruned_heads.union(heads)
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states,
221
+ attention_mask=None,
222
+ head_mask=None,
223
+ encoder_hidden_states=None,
224
+ encoder_attention_mask=None,
225
+ output_attentions=False,
226
+ ):
227
+ h1, h2 = self.clone(hidden_states, 2)
228
+ self_outputs = self.self(
229
+ h1,
230
+ attention_mask,
231
+ head_mask,
232
+ encoder_hidden_states,
233
+ encoder_attention_mask,
234
+ output_attentions,
235
+ )
236
+ attention_output = self.output(self_outputs[0], h2)
237
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
238
+ return outputs
239
+
240
+ def relprop(self, cam, **kwargs):
241
+ # assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
242
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
243
+ #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
244
+ cam1 = self.self.relprop(cam1, **kwargs)
245
+ #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
246
+
247
+ return self.clone.relprop((cam1, cam2), **kwargs)
248
+
249
+ class BertSelfAttention(nn.Module):
250
+ def __init__(self, config):
251
+ super().__init__()
252
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
253
+ raise ValueError(
254
+ "The hidden size (%d) is not a multiple of the number of attention "
255
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
256
+ )
257
+
258
+ self.num_attention_heads = config.num_attention_heads
259
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
260
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
261
+
262
+ self.query = Linear(config.hidden_size, self.all_head_size)
263
+ self.key = Linear(config.hidden_size, self.all_head_size)
264
+ self.value = Linear(config.hidden_size, self.all_head_size)
265
+
266
+ self.dropout = Dropout(config.attention_probs_dropout_prob)
267
+
268
+ self.matmul1 = MatMul()
269
+ self.matmul2 = MatMul()
270
+ self.softmax = Softmax(dim=-1)
271
+ self.add = Add()
272
+ self.mul = Mul()
273
+ self.head_mask = None
274
+ self.attention_mask = None
275
+ self.clone = Clone()
276
+
277
+ self.attn_cam = None
278
+ self.attn = None
279
+ self.attn_gradients = None
280
+
281
+ def get_attn(self):
282
+ return self.attn
283
+
284
+ def save_attn(self, attn):
285
+ self.attn = attn
286
+
287
+ def save_attn_cam(self, cam):
288
+ self.attn_cam = cam
289
+
290
+ def get_attn_cam(self):
291
+ return self.attn_cam
292
+
293
+ def save_attn_gradients(self, attn_gradients):
294
+ self.attn_gradients = attn_gradients
295
+
296
+ def get_attn_gradients(self):
297
+ return self.attn_gradients
298
+
299
+ def transpose_for_scores(self, x):
300
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
301
+ x = x.view(*new_x_shape)
302
+ return x.permute(0, 2, 1, 3)
303
+
304
+ def transpose_for_scores_relprop(self, x):
305
+ return x.permute(0, 2, 1, 3).flatten(2)
306
+
307
+ def forward(
308
+ self,
309
+ hidden_states,
310
+ attention_mask=None,
311
+ head_mask=None,
312
+ encoder_hidden_states=None,
313
+ encoder_attention_mask=None,
314
+ output_attentions=False,
315
+ ):
316
+ self.head_mask = head_mask
317
+ self.attention_mask = attention_mask
318
+
319
+ h1, h2, h3 = self.clone(hidden_states, 3)
320
+ mixed_query_layer = self.query(h1)
321
+
322
+ # If this is instantiated as a cross-attention module, the keys
323
+ # and values come from an encoder; the attention mask needs to be
324
+ # such that the encoder's padding tokens are not attended to.
325
+ if encoder_hidden_states is not None:
326
+ mixed_key_layer = self.key(encoder_hidden_states)
327
+ mixed_value_layer = self.value(encoder_hidden_states)
328
+ attention_mask = encoder_attention_mask
329
+ else:
330
+ mixed_key_layer = self.key(h2)
331
+ mixed_value_layer = self.value(h3)
332
+
333
+ query_layer = self.transpose_for_scores(mixed_query_layer)
334
+ key_layer = self.transpose_for_scores(mixed_key_layer)
335
+ value_layer = self.transpose_for_scores(mixed_value_layer)
336
+
337
+ # Take the dot product between "query" and "key" to get the raw attention scores.
338
+ attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
339
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
340
+ if attention_mask is not None:
341
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
342
+ attention_scores = self.add([attention_scores, attention_mask])
343
+
344
+ # Normalize the attention scores to probabilities.
345
+ attention_probs = self.softmax(attention_scores)
346
+
347
+ self.save_attn(attention_probs)
348
+ attention_probs.register_hook(self.save_attn_gradients)
349
+
350
+ # This is actually dropping out entire tokens to attend to, which might
351
+ # seem a bit unusual, but is taken from the original Transformer paper.
352
+ attention_probs = self.dropout(attention_probs)
353
+
354
+ # Mask heads if we want to
355
+ if head_mask is not None:
356
+ attention_probs = attention_probs * head_mask
357
+
358
+ context_layer = self.matmul2([attention_probs, value_layer])
359
+
360
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
361
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
362
+ context_layer = context_layer.view(*new_context_layer_shape)
363
+
364
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
365
+ return outputs
366
+
367
+ def relprop(self, cam, **kwargs):
368
+ # Assume output_attentions == False
369
+ cam = self.transpose_for_scores(cam)
370
+
371
+ # [attention_probs, value_layer]
372
+ (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
373
+ cam1 /= 2
374
+ cam2 /= 2
375
+ if self.head_mask is not None:
376
+ # [attention_probs, head_mask]
377
+ (cam1, _)= self.mul.relprop(cam1, **kwargs)
378
+
379
+
380
+ self.save_attn_cam(cam1)
381
+
382
+ cam1 = self.dropout.relprop(cam1, **kwargs)
383
+
384
+ cam1 = self.softmax.relprop(cam1, **kwargs)
385
+
386
+ if self.attention_mask is not None:
387
+ # [attention_scores, attention_mask]
388
+ (cam1, _) = self.add.relprop(cam1, **kwargs)
389
+
390
+ # [query_layer, key_layer.transpose(-1, -2)]
391
+ (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
392
+ cam1_1 /= 2
393
+ cam1_2 /= 2
394
+
395
+ # query
396
+ cam1_1 = self.transpose_for_scores_relprop(cam1_1)
397
+ cam1_1 = self.query.relprop(cam1_1, **kwargs)
398
+
399
+ # key
400
+ cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
401
+ cam1_2 = self.key.relprop(cam1_2, **kwargs)
402
+
403
+ # value
404
+ cam2 = self.transpose_for_scores_relprop(cam2)
405
+ cam2 = self.value.relprop(cam2, **kwargs)
406
+
407
+ cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
408
+
409
+ return cam
410
+
411
+
412
+ class BertSelfOutput(nn.Module):
413
+ def __init__(self, config):
414
+ super().__init__()
415
+ self.dense = Linear(config.hidden_size, config.hidden_size)
416
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
417
+ self.dropout = Dropout(config.hidden_dropout_prob)
418
+ self.add = Add()
419
+
420
+ def forward(self, hidden_states, input_tensor):
421
+ hidden_states = self.dense(hidden_states)
422
+ hidden_states = self.dropout(hidden_states)
423
+ add = self.add([hidden_states, input_tensor])
424
+ hidden_states = self.LayerNorm(add)
425
+ return hidden_states
426
+
427
+ def relprop(self, cam, **kwargs):
428
+ cam = self.LayerNorm.relprop(cam, **kwargs)
429
+ # [hidden_states, input_tensor]
430
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
431
+ cam1 = self.dropout.relprop(cam1, **kwargs)
432
+ cam1 = self.dense.relprop(cam1, **kwargs)
433
+
434
+ return (cam1, cam2)
435
+
436
+
437
+ class BertIntermediate(nn.Module):
438
+ def __init__(self, config):
439
+ super().__init__()
440
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
441
+ if isinstance(config.hidden_act, str):
442
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]()
443
+ else:
444
+ self.intermediate_act_fn = config.hidden_act
445
+
446
+ def forward(self, hidden_states):
447
+ hidden_states = self.dense(hidden_states)
448
+ hidden_states = self.intermediate_act_fn(hidden_states)
449
+ return hidden_states
450
+
451
+ def relprop(self, cam, **kwargs):
452
+ cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
453
+ #print(cam.sum())
454
+ cam = self.dense.relprop(cam, **kwargs)
455
+ #print(cam.sum())
456
+ return cam
457
+
458
+
459
+ class BertOutput(nn.Module):
460
+ def __init__(self, config):
461
+ super().__init__()
462
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
463
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
464
+ self.dropout = Dropout(config.hidden_dropout_prob)
465
+ self.add = Add()
466
+
467
+ def forward(self, hidden_states, input_tensor):
468
+ hidden_states = self.dense(hidden_states)
469
+ hidden_states = self.dropout(hidden_states)
470
+ add = self.add([hidden_states, input_tensor])
471
+ hidden_states = self.LayerNorm(add)
472
+ return hidden_states
473
+
474
+ def relprop(self, cam, **kwargs):
475
+ # print("in", cam.sum())
476
+ cam = self.LayerNorm.relprop(cam, **kwargs)
477
+ #print(cam.sum())
478
+ # [hidden_states, input_tensor]
479
+ (cam1, cam2)= self.add.relprop(cam, **kwargs)
480
+ # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
481
+ cam1 = self.dropout.relprop(cam1, **kwargs)
482
+ #print(cam1.sum())
483
+ cam1 = self.dense.relprop(cam1, **kwargs)
484
+ # print("dense", cam1.sum())
485
+
486
+ # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
487
+ return (cam1, cam2)
488
+
489
+
490
+ class BertLayer(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.attention = BertAttention(config)
494
+ self.intermediate = BertIntermediate(config)
495
+ self.output = BertOutput(config)
496
+ self.clone = Clone()
497
+
498
+ def forward(
499
+ self,
500
+ hidden_states,
501
+ attention_mask=None,
502
+ head_mask=None,
503
+ output_attentions=False,
504
+ ):
505
+ self_attention_outputs = self.attention(
506
+ hidden_states,
507
+ attention_mask,
508
+ head_mask,
509
+ output_attentions=output_attentions,
510
+ )
511
+ attention_output = self_attention_outputs[0]
512
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
513
+
514
+ ao1, ao2 = self.clone(attention_output, 2)
515
+ intermediate_output = self.intermediate(ao1)
516
+ layer_output = self.output(intermediate_output, ao2)
517
+
518
+ outputs = (layer_output,) + outputs
519
+ return outputs
520
+
521
+ def relprop(self, cam, **kwargs):
522
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
523
+ # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
524
+ cam1 = self.intermediate.relprop(cam1, **kwargs)
525
+ # print("intermediate", cam1.sum())
526
+ cam = self.clone.relprop((cam1, cam2), **kwargs)
527
+ # print("clone", cam.sum())
528
+ cam = self.attention.relprop(cam, **kwargs)
529
+ # print("attention", cam.sum())
530
+ return cam
531
+
532
+
533
+ class BertModel(BertPreTrainedModel):
534
+ def __init__(self, config):
535
+ super().__init__(config)
536
+ self.config = config
537
+
538
+ self.embeddings = BertEmbeddings(config)
539
+ self.encoder = BertEncoder(config)
540
+ self.pooler = BertPooler(config)
541
+
542
+ self.init_weights()
543
+
544
+ def get_input_embeddings(self):
545
+ return self.embeddings.word_embeddings
546
+
547
+ def set_input_embeddings(self, value):
548
+ self.embeddings.word_embeddings = value
549
+
550
+ def forward(
551
+ self,
552
+ input_ids=None,
553
+ attention_mask=None,
554
+ token_type_ids=None,
555
+ position_ids=None,
556
+ head_mask=None,
557
+ inputs_embeds=None,
558
+ encoder_hidden_states=None,
559
+ encoder_attention_mask=None,
560
+ output_attentions=None,
561
+ output_hidden_states=None,
562
+ return_dict=None,
563
+ ):
564
+ r"""
565
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
566
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
567
+ if the model is configured as a decoder.
568
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
569
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
570
+ is used in the cross-attention if the model is configured as a decoder.
571
+ Mask values selected in ``[0, 1]``:
572
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
573
+ """
574
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
575
+ output_hidden_states = (
576
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
577
+ )
578
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
579
+
580
+ if input_ids is not None and inputs_embeds is not None:
581
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
582
+ elif input_ids is not None:
583
+ input_shape = input_ids.size()
584
+ elif inputs_embeds is not None:
585
+ input_shape = inputs_embeds.size()[:-1]
586
+ else:
587
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
588
+
589
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
590
+
591
+ if attention_mask is None:
592
+ attention_mask = torch.ones(input_shape, device=device)
593
+ if token_type_ids is None:
594
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
595
+
596
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
597
+ # ourselves in which case we just need to make it broadcastable to all heads.
598
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
599
+
600
+ # If a 2D or 3D attention mask is provided for the cross-attention
601
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
602
+ if self.config.is_decoder and encoder_hidden_states is not None:
603
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
604
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
605
+ if encoder_attention_mask is None:
606
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
607
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
608
+ else:
609
+ encoder_extended_attention_mask = None
610
+
611
+ # Prepare head mask if needed
612
+ # 1.0 in head_mask indicate we keep the head
613
+ # attention_probs has shape bsz x n_heads x N x N
614
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
615
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
616
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
617
+
618
+ embedding_output = self.embeddings(
619
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
620
+ )
621
+
622
+ encoder_outputs = self.encoder(
623
+ embedding_output,
624
+ attention_mask=extended_attention_mask,
625
+ head_mask=head_mask,
626
+ encoder_hidden_states=encoder_hidden_states,
627
+ encoder_attention_mask=encoder_extended_attention_mask,
628
+ output_attentions=output_attentions,
629
+ output_hidden_states=output_hidden_states,
630
+ return_dict=return_dict,
631
+ )
632
+ sequence_output = encoder_outputs[0]
633
+ pooled_output = self.pooler(sequence_output)
634
+
635
+ if not return_dict:
636
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
637
+
638
+ return BaseModelOutputWithPooling(
639
+ last_hidden_state=sequence_output,
640
+ pooler_output=pooled_output,
641
+ hidden_states=encoder_outputs.hidden_states,
642
+ attentions=encoder_outputs.attentions,
643
+ )
644
+
645
+ def relprop(self, cam, **kwargs):
646
+ cam = self.pooler.relprop(cam, **kwargs)
647
+ # print("111111111111",cam.sum())
648
+ cam = self.encoder.relprop(cam, **kwargs)
649
+ # print("222222222222222", cam.sum())
650
+ # print("conservation: ", cam.sum())
651
+ return cam
652
+
653
+
654
+ if __name__ == '__main__':
655
+ class Config:
656
+ def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
657
+ self.hidden_size = hidden_size
658
+ self.num_attention_heads = num_attention_heads
659
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
660
+
661
+ model = BertSelfAttention(Config(1024, 4, 0.1))
662
+ x = torch.rand(2, 20, 1024)
663
+ x.requires_grad_()
664
+
665
+ model.eval()
666
+
667
+ y = model.forward(x)
668
+
669
+ relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))
670
+
671
+ print(relprop[1][0].shape)
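The compute_rollout_attention helper defined near the top of this file implements attention rollout: an identity matrix is added to each layer's attention map to account for residual connections, the rows are re-normalized, and the maps are multiplied together from start_layer onward. A minimal usage sketch with made-up inputs (assuming the BERT_explainability package is importable and each matrix is already averaged over heads):

import torch
from BERT_explainability.modules.BERT.BERT import compute_rollout_attention

# Illustrative toy input (not from the upload): 12 layers of head-averaged
# attention maps with shape (batch=1, tokens=5, tokens=5).
layer_maps = [torch.softmax(torch.rand(1, 5, 5), dim=-1) for _ in range(12)]

# Rollout: add the residual (identity), re-normalize rows, and multiply the
# per-layer maps together from start_layer onward.
rollout = compute_rollout_attention(layer_maps, start_layer=0)
print(rollout.shape)  # torch.Size([1, 5, 5]); row i gives each token's relevance to token i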
BERT/BERT_explainability/modules/BERT/BERT_cls_lrp.py ADDED
@@ -0,0 +1,202 @@
1
+ from transformers import BertPreTrainedModel
2
+ from transformers.utils import logging
3
+ from BERT_explainability.modules.layers_lrp import *
4
+ from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel
5
+ from torch.nn import CrossEntropyLoss, MSELoss
6
+ import torch.nn as nn
7
+ from typing import List, Any
8
+ import torch
9
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
10
+
11
+
12
+ class BertForSequenceClassification(BertPreTrainedModel):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ self.num_labels = config.num_labels
16
+
17
+ self.bert = BertModel(config)
18
+ self.dropout = Dropout(config.hidden_dropout_prob)
19
+ self.classifier = Linear(config.hidden_size, config.num_labels)
20
+
21
+ self.init_weights()
22
+
23
+ def forward(
24
+ self,
25
+ input_ids=None,
26
+ attention_mask=None,
27
+ token_type_ids=None,
28
+ position_ids=None,
29
+ head_mask=None,
30
+ inputs_embeds=None,
31
+ labels=None,
32
+ output_attentions=None,
33
+ output_hidden_states=None,
34
+ return_dict=None,
35
+ ):
36
+ r"""
37
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
38
+ Labels for computing the sequence classification/regression loss.
39
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
40
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
41
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
42
+ """
43
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
44
+
45
+ outputs = self.bert(
46
+ input_ids,
47
+ attention_mask=attention_mask,
48
+ token_type_ids=token_type_ids,
49
+ position_ids=position_ids,
50
+ head_mask=head_mask,
51
+ inputs_embeds=inputs_embeds,
52
+ output_attentions=output_attentions,
53
+ output_hidden_states=output_hidden_states,
54
+ return_dict=return_dict,
55
+ )
56
+
57
+ pooled_output = outputs[1]
58
+
59
+ pooled_output = self.dropout(pooled_output)
60
+ logits = self.classifier(pooled_output)
61
+
62
+ loss = None
63
+ if labels is not None:
64
+ if self.num_labels == 1:
65
+ # We are doing regression
66
+ loss_fct = MSELoss()
67
+ loss = loss_fct(logits.view(-1), labels.view(-1))
68
+ else:
69
+ loss_fct = CrossEntropyLoss()
70
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
71
+
72
+ if not return_dict:
73
+ output = (logits,) + outputs[2:]
74
+ return ((loss,) + output) if loss is not None else output
75
+
76
+ return SequenceClassifierOutput(
77
+ loss=loss,
78
+ logits=logits,
79
+ hidden_states=outputs.hidden_states,
80
+ attentions=outputs.attentions,
81
+ )
82
+
83
+ def relprop(self, cam=None, **kwargs):
84
+ cam = self.classifier.relprop(cam, **kwargs)
85
+ cam = self.dropout.relprop(cam, **kwargs)
86
+ cam = self.bert.relprop(cam, **kwargs)
87
+ return cam
88
+
89
+
90
+ # this is the actual classifier we will be using
91
+ class BertClassifier(nn.Module):
92
+ """Thin wrapper around BertForSequenceClassification"""
93
+
94
+ def __init__(self,
95
+ bert_dir: str,
96
+ pad_token_id: int,
97
+ cls_token_id: int,
98
+ sep_token_id: int,
99
+ num_labels: int,
100
+ max_length: int = 512,
101
+ use_half_precision=True):
102
+ super(BertClassifier, self).__init__()
103
+ bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
104
+ if use_half_precision:
105
+ import apex
106
+ bert = bert.half()
107
+ self.bert = bert
108
+ self.pad_token_id = pad_token_id
109
+ self.cls_token_id = cls_token_id
110
+ self.sep_token_id = sep_token_id
111
+ self.max_length = max_length
112
+
113
+ def forward(self,
114
+ query: List[torch.tensor],
115
+ docids: List[Any],
116
+ document_batch: List[torch.tensor]):
117
+ assert len(query) == len(document_batch)
118
+ print(query)
119
+ # note about device management:
120
+ # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
121
+ # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
122
+ target_device = next(self.parameters()).device
123
+ cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
124
+ sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
125
+ input_tensors = []
126
+ position_ids = []
127
+ for q, d in zip(query, document_batch):
128
+ if len(q) + len(d) + 2 > self.max_length:
129
+ d = d[:(self.max_length - len(q) - 2)]
130
+ input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
131
+ position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
132
+ bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
133
+ device=target_device)
134
+ positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
135
+ (classes,) = self.bert(bert_input.data,
136
+ attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
137
+ position_ids=positions.data)
138
+ assert torch.all(classes == classes) # for nans
139
+
140
+ print(input_tensors[0])
141
+ print(self.relprop()[0])
142
+
143
+ return classes
144
+
145
+ def relprop(self, cam=None, **kwargs):
146
+ return self.bert.relprop(cam, **kwargs)
147
+
148
+
149
+ if __name__ == '__main__':
150
+ from transformers import BertTokenizer
151
+ import os
152
+
153
+ class Config:
154
+ def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
155
+ hidden_dropout_prob):
156
+ self.hidden_size = hidden_size
157
+ self.num_attention_heads = num_attention_heads
158
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
159
+ self.num_labels = num_labels
160
+ self.hidden_dropout_prob = hidden_dropout_prob
161
+
162
+
163
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
164
+ x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
165
+ add_special_tokens=True,
166
+ max_length=512,
167
+ return_token_type_ids=False,
168
+ return_attention_mask=True,
169
+ pad_to_max_length=True,
170
+ return_tensors='pt',
171
+ truncation=True)
172
+
173
+ print(x['input_ids'])
174
+
175
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
176
+ model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
177
+ model.load_state_dict(torch.load(model_save_file))
178
+
179
+ # x = torch.randint(100, (2, 20))
180
+ # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
181
+ # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
182
+ # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
183
+ # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
184
+ # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
185
+ # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
186
+ # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
187
+ # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
188
+ # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
189
+ # 102, 101, 1012, 102]])
190
+ # x.requires_grad_()
191
+
192
+ model.eval()
193
+
194
+ y = model(x['input_ids'], x['attention_mask'])
195
+ print(y)
196
+
197
+ cam, _ = model.relprop()
198
+
199
+ #print(cam.shape)
200
+
201
+ cam = cam.sum(-1)
202
+ #print(cam)
BERT/BERT_explainability/modules/BERT/BERT_orig_lrp.py ADDED
@@ -0,0 +1,671 @@
1
+ from __future__ import absolute_import
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from transformers import BertConfig
8
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
9
+ from BERT_explainability.modules.layers_lrp import *
10
+ from transformers import (
11
+ BertPreTrainedModel,
12
+ PreTrainedModel,
13
+ )
14
+
15
+ ACT2FN = {
16
+ "relu": ReLU,
17
+ "tanh": Tanh,
18
+ "gelu": GELU,
19
+ }
20
+
21
+
22
+ def get_activation(activation_string):
23
+ if activation_string in ACT2FN:
24
+ return ACT2FN[activation_string]
25
+ else:
26
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
27
+
28
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
29
+ # adding residual consideration
30
+ num_tokens = all_layer_matrices[0].shape[1]
31
+ batch_size = all_layer_matrices[0].shape[0]
32
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
33
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
34
+ all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
35
+ for i in range(len(all_layer_matrices))]
36
+ joint_attention = all_layer_matrices[start_layer]
37
+ for i in range(start_layer+1, len(all_layer_matrices)):
38
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
39
+ return joint_attention
40
+
41
+ class BertEmbeddings(nn.Module):
42
+ """Construct the embeddings from word, position and token_type embeddings."""
43
+
44
+ def __init__(self, config):
45
+ super().__init__()
46
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
47
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
48
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
49
+
50
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
51
+ # any TensorFlow checkpoint file
52
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
53
+ self.dropout = Dropout(config.hidden_dropout_prob)
54
+
55
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
56
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
57
+
58
+ self.add1 = Add()
59
+ self.add2 = Add()
60
+
61
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
62
+ if input_ids is not None:
63
+ input_shape = input_ids.size()
64
+ else:
65
+ input_shape = inputs_embeds.size()[:-1]
66
+
67
+ seq_length = input_shape[1]
68
+
69
+ if position_ids is None:
70
+ position_ids = self.position_ids[:, :seq_length]
71
+
72
+ if token_type_ids is None:
73
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
74
+
75
+ if inputs_embeds is None:
76
+ inputs_embeds = self.word_embeddings(input_ids)
77
+ position_embeddings = self.position_embeddings(position_ids)
78
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
79
+
80
+ # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
81
+ embeddings = self.add1([token_type_embeddings, position_embeddings])
82
+ embeddings = self.add2([embeddings, inputs_embeds])
83
+ embeddings = self.LayerNorm(embeddings)
84
+ embeddings = self.dropout(embeddings)
85
+ return embeddings
86
+
87
+ def relprop(self, cam, **kwargs):
88
+ cam = self.dropout.relprop(cam, **kwargs)
89
+ cam = self.LayerNorm.relprop(cam, **kwargs)
90
+
91
+ # [inputs_embeds, position_embeddings, token_type_embeddings]
92
+ (cam) = self.add2.relprop(cam, **kwargs)
93
+
94
+ return cam
95
+
96
+ class BertEncoder(nn.Module):
97
+ def __init__(self, config):
98
+ super().__init__()
99
+ self.config = config
100
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
101
+
102
+ def forward(
103
+ self,
104
+ hidden_states,
105
+ attention_mask=None,
106
+ head_mask=None,
107
+ encoder_hidden_states=None,
108
+ encoder_attention_mask=None,
109
+ output_attentions=False,
110
+ output_hidden_states=False,
111
+ return_dict=False,
112
+ ):
113
+ all_hidden_states = () if output_hidden_states else None
114
+ all_attentions = () if output_attentions else None
115
+ for i, layer_module in enumerate(self.layer):
116
+ if output_hidden_states:
117
+ all_hidden_states = all_hidden_states + (hidden_states,)
118
+
119
+ layer_head_mask = head_mask[i] if head_mask is not None else None
120
+
121
+ if getattr(self.config, "gradient_checkpointing", False):
122
+
123
+ def create_custom_forward(module):
124
+ def custom_forward(*inputs):
125
+ return module(*inputs, output_attentions)
126
+
127
+ return custom_forward
128
+
129
+ layer_outputs = torch.utils.checkpoint.checkpoint(
130
+ create_custom_forward(layer_module),
131
+ hidden_states,
132
+ attention_mask,
133
+ layer_head_mask,
134
+ )
135
+ else:
136
+ layer_outputs = layer_module(
137
+ hidden_states,
138
+ attention_mask,
139
+ layer_head_mask,
140
+ output_attentions,
141
+ )
142
+ hidden_states = layer_outputs[0]
143
+ if output_attentions:
144
+ all_attentions = all_attentions + (layer_outputs[1],)
145
+
146
+ if output_hidden_states:
147
+ all_hidden_states = all_hidden_states + (hidden_states,)
148
+
149
+ if not return_dict:
150
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
151
+ return BaseModelOutput(
152
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
153
+ )
154
+
155
+ def relprop(self, cam, **kwargs):
156
+ # assuming output_hidden_states is False
157
+ for layer_module in reversed(self.layer):
158
+ cam = layer_module.relprop(cam, **kwargs)
159
+ return cam
160
+
161
+ # not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
162
+ class BertPooler(nn.Module):
163
+ def __init__(self, config):
164
+ super().__init__()
165
+ self.dense = Linear(config.hidden_size, config.hidden_size)
166
+ self.activation = Tanh()
167
+ self.pool = IndexSelect()
168
+
169
+ def forward(self, hidden_states):
170
+ # We "pool" the model by simply taking the hidden state corresponding
171
+ # to the first token.
172
+ self._seq_size = hidden_states.shape[1]
173
+
174
+ # first_token_tensor = hidden_states[:, 0]
175
+ first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
176
+ first_token_tensor = first_token_tensor.squeeze(1)
177
+ pooled_output = self.dense(first_token_tensor)
178
+ pooled_output = self.activation(pooled_output)
179
+ return pooled_output
180
+
181
+ def relprop(self, cam, **kwargs):
182
+ cam = self.activation.relprop(cam, **kwargs)
183
+ #print(cam.sum())
184
+ cam = self.dense.relprop(cam, **kwargs)
185
+ #print(cam.sum())
186
+ cam = cam.unsqueeze(1)
187
+ cam = self.pool.relprop(cam, **kwargs)
188
+ #print(cam.sum())
189
+
190
+ return cam
191
+
192
+ class BertAttention(nn.Module):
193
+ def __init__(self, config):
194
+ super().__init__()
195
+ self.self = BertSelfAttention(config)
196
+ self.output = BertSelfOutput(config)
197
+ self.pruned_heads = set()
198
+ self.clone = Clone()
199
+
200
+ def prune_heads(self, heads):
201
+ if len(heads) == 0:
202
+ return
203
+ heads, index = find_pruneable_heads_and_indices(
204
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
205
+ )
206
+
207
+ # Prune linear layers
208
+ self.self.query = prune_linear_layer(self.self.query, index)
209
+ self.self.key = prune_linear_layer(self.self.key, index)
210
+ self.self.value = prune_linear_layer(self.self.value, index)
211
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
212
+
213
+ # Update hyper params and store pruned heads
214
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
215
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
216
+ self.pruned_heads = self.pruned_heads.union(heads)
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states,
221
+ attention_mask=None,
222
+ head_mask=None,
223
+ encoder_hidden_states=None,
224
+ encoder_attention_mask=None,
225
+ output_attentions=False,
226
+ ):
227
+ h1, h2 = self.clone(hidden_states, 2)
228
+ self_outputs = self.self(
229
+ h1,
230
+ attention_mask,
231
+ head_mask,
232
+ encoder_hidden_states,
233
+ encoder_attention_mask,
234
+ output_attentions,
235
+ )
236
+ attention_output = self.output(self_outputs[0], h2)
237
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
238
+ return outputs
239
+
240
+ def relprop(self, cam, **kwargs):
241
+ # assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
242
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
243
+ #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
244
+ cam1 = self.self.relprop(cam1, **kwargs)
245
+ #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
246
+
247
+ return self.clone.relprop((cam1, cam2), **kwargs)
248
+
249
+ class BertSelfAttention(nn.Module):
250
+ def __init__(self, config):
251
+ super().__init__()
252
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
253
+ raise ValueError(
254
+ "The hidden size (%d) is not a multiple of the number of attention "
255
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
256
+ )
257
+
258
+ self.num_attention_heads = config.num_attention_heads
259
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
260
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
261
+
262
+ self.query = Linear(config.hidden_size, self.all_head_size)
263
+ self.key = Linear(config.hidden_size, self.all_head_size)
264
+ self.value = Linear(config.hidden_size, self.all_head_size)
265
+
266
+ self.dropout = Dropout(config.attention_probs_dropout_prob)
267
+
268
+ self.matmul1 = MatMul()
269
+ self.matmul2 = MatMul()
270
+ self.softmax = Softmax(dim=-1)
271
+ self.add = Add()
272
+ self.mul = Mul()
273
+ self.head_mask = None
274
+ self.attention_mask = None
275
+ self.clone = Clone()
276
+
277
+ self.attn_cam = None
278
+ self.attn = None
279
+ self.attn_gradients = None
280
+
281
+ def get_attn(self):
282
+ return self.attn
283
+
284
+ def save_attn(self, attn):
285
+ self.attn = attn
286
+
287
+ def save_attn_cam(self, cam):
288
+ self.attn_cam = cam
289
+
290
+ def get_attn_cam(self):
291
+ return self.attn_cam
292
+
293
+ def save_attn_gradients(self, attn_gradients):
294
+ self.attn_gradients = attn_gradients
295
+
296
+ def get_attn_gradients(self):
297
+ return self.attn_gradients
298
+
299
+ def transpose_for_scores(self, x):
300
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
301
+ x = x.view(*new_x_shape)
302
+ return x.permute(0, 2, 1, 3)
303
+
304
+ def transpose_for_scores_relprop(self, x):
305
+ return x.permute(0, 2, 1, 3).flatten(2)
306
+
307
+ def forward(
308
+ self,
309
+ hidden_states,
310
+ attention_mask=None,
311
+ head_mask=None,
312
+ encoder_hidden_states=None,
313
+ encoder_attention_mask=None,
314
+ output_attentions=False,
315
+ ):
316
+ self.head_mask = head_mask
317
+ self.attention_mask = attention_mask
318
+
319
+ h1, h2, h3 = self.clone(hidden_states, 3)
320
+ mixed_query_layer = self.query(h1)
321
+
322
+ # If this is instantiated as a cross-attention module, the keys
323
+ # and values come from an encoder; the attention mask needs to be
324
+ # such that the encoder's padding tokens are not attended to.
325
+ if encoder_hidden_states is not None:
326
+ mixed_key_layer = self.key(encoder_hidden_states)
327
+ mixed_value_layer = self.value(encoder_hidden_states)
328
+ attention_mask = encoder_attention_mask
329
+ else:
330
+ mixed_key_layer = self.key(h2)
331
+ mixed_value_layer = self.value(h3)
332
+
333
+ query_layer = self.transpose_for_scores(mixed_query_layer)
334
+ key_layer = self.transpose_for_scores(mixed_key_layer)
335
+ value_layer = self.transpose_for_scores(mixed_value_layer)
336
+
337
+ # Take the dot product between "query" and "key" to get the raw attention scores.
338
+ attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
339
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
340
+ if attention_mask is not None:
341
+ # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
342
+ attention_scores = self.add([attention_scores, attention_mask])
343
+
344
+ # Normalize the attention scores to probabilities.
345
+ attention_probs = self.softmax(attention_scores)
346
+
347
+ self.save_attn(attention_probs)
348
+ attention_probs.register_hook(self.save_attn_gradients)
349
+
350
+ # This is actually dropping out entire tokens to attend to, which might
351
+ # seem a bit unusual, but is taken from the original Transformer paper.
352
+ attention_probs = self.dropout(attention_probs)
353
+
354
+ # Mask heads if we want to
355
+ if head_mask is not None:
356
+ attention_probs = attention_probs * head_mask
357
+
358
+ context_layer = self.matmul2([attention_probs, value_layer])
359
+
360
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
361
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
362
+ context_layer = context_layer.view(*new_context_layer_shape)
363
+
364
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
365
+ return outputs
366
+
367
+ def relprop(self, cam, **kwargs):
368
+ # Assume output_attentions == False
369
+ cam = self.transpose_for_scores(cam)
370
+
371
+ # [attention_probs, value_layer]
372
+ (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
373
+ cam1 /= 2
374
+ cam2 /= 2
375
+ if self.head_mask is not None:
376
+ # [attention_probs, head_mask]
377
+ (cam1, _)= self.mul.relprop(cam1, **kwargs)
378
+
379
+
380
+ self.save_attn_cam(cam1)
381
+
382
+ cam1 = self.dropout.relprop(cam1, **kwargs)
383
+
384
+ cam1 = self.softmax.relprop(cam1, **kwargs)
385
+
386
+ if self.attention_mask is not None:
387
+ # [attention_scores, attention_mask]
388
+ (cam1, _) = self.add.relprop(cam1, **kwargs)
389
+
390
+ # [query_layer, key_layer.transpose(-1, -2)]
391
+ (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
392
+ cam1_1 /= 2
393
+ cam1_2 /= 2
394
+
395
+ # query
396
+ cam1_1 = self.transpose_for_scores_relprop(cam1_1)
397
+ cam1_1 = self.query.relprop(cam1_1, **kwargs)
398
+
399
+ # key
400
+ cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
401
+ cam1_2 = self.key.relprop(cam1_2, **kwargs)
402
+
403
+ # value
404
+ cam2 = self.transpose_for_scores_relprop(cam2)
405
+ cam2 = self.value.relprop(cam2, **kwargs)
406
+
407
+ cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
408
+
409
+ return cam
410
+
411
+
412
+ class BertSelfOutput(nn.Module):
413
+ def __init__(self, config):
414
+ super().__init__()
415
+ self.dense = Linear(config.hidden_size, config.hidden_size)
416
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
417
+ self.dropout = Dropout(config.hidden_dropout_prob)
418
+ self.add = Add()
419
+
420
+ def forward(self, hidden_states, input_tensor):
421
+ hidden_states = self.dense(hidden_states)
422
+ hidden_states = self.dropout(hidden_states)
423
+ add = self.add([hidden_states, input_tensor])
424
+ hidden_states = self.LayerNorm(add)
425
+ return hidden_states
426
+
427
+ def relprop(self, cam, **kwargs):
428
+ cam = self.LayerNorm.relprop(cam, **kwargs)
429
+ # [hidden_states, input_tensor]
430
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
431
+ cam1 = self.dropout.relprop(cam1, **kwargs)
432
+ cam1 = self.dense.relprop(cam1, **kwargs)
433
+
434
+ return (cam1, cam2)
435
+
436
+
437
+ class BertIntermediate(nn.Module):
438
+ def __init__(self, config):
439
+ super().__init__()
440
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
441
+ if isinstance(config.hidden_act, str):
442
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]()
443
+ else:
444
+ self.intermediate_act_fn = config.hidden_act
445
+
446
+ def forward(self, hidden_states):
447
+ hidden_states = self.dense(hidden_states)
448
+ hidden_states = self.intermediate_act_fn(hidden_states)
449
+ return hidden_states
450
+
451
+ def relprop(self, cam, **kwargs):
452
+ cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
453
+ #print(cam.sum())
454
+ cam = self.dense.relprop(cam, **kwargs)
455
+ #print(cam.sum())
456
+ return cam
457
+
458
+
459
+ class BertOutput(nn.Module):
460
+ def __init__(self, config):
461
+ super().__init__()
462
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
463
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
464
+ self.dropout = Dropout(config.hidden_dropout_prob)
465
+ self.add = Add()
466
+
467
+ def forward(self, hidden_states, input_tensor):
468
+ hidden_states = self.dense(hidden_states)
469
+ hidden_states = self.dropout(hidden_states)
470
+ add = self.add([hidden_states, input_tensor])
471
+ hidden_states = self.LayerNorm(add)
472
+ return hidden_states
473
+
474
+ def relprop(self, cam, **kwargs):
475
+ # print("in", cam.sum())
476
+ cam = self.LayerNorm.relprop(cam, **kwargs)
477
+ #print(cam.sum())
478
+ # [hidden_states, input_tensor]
479
+ (cam1, cam2)= self.add.relprop(cam, **kwargs)
480
+ # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
481
+ cam1 = self.dropout.relprop(cam1, **kwargs)
482
+ #print(cam1.sum())
483
+ cam1 = self.dense.relprop(cam1, **kwargs)
484
+ # print("dense", cam1.sum())
485
+
486
+ # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
487
+ return (cam1, cam2)
488
+
489
+
490
+ class BertLayer(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.attention = BertAttention(config)
494
+ self.intermediate = BertIntermediate(config)
495
+ self.output = BertOutput(config)
496
+ self.clone = Clone()
497
+
498
+ def forward(
499
+ self,
500
+ hidden_states,
501
+ attention_mask=None,
502
+ head_mask=None,
503
+ output_attentions=False,
504
+ ):
505
+ self_attention_outputs = self.attention(
506
+ hidden_states,
507
+ attention_mask,
508
+ head_mask,
509
+ output_attentions=output_attentions,
510
+ )
511
+ attention_output = self_attention_outputs[0]
512
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
513
+
514
+ ao1, ao2 = self.clone(attention_output, 2)
515
+ intermediate_output = self.intermediate(ao1)
516
+ layer_output = self.output(intermediate_output, ao2)
517
+
518
+ outputs = (layer_output,) + outputs
519
+ return outputs
520
+
521
+ def relprop(self, cam, **kwargs):
522
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
523
+ # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
524
+ cam1 = self.intermediate.relprop(cam1, **kwargs)
525
+ # print("intermediate", cam1.sum())
526
+ cam = self.clone.relprop((cam1, cam2), **kwargs)
527
+ # print("clone", cam.sum())
528
+ cam = self.attention.relprop(cam, **kwargs)
529
+ # print("attention", cam.sum())
530
+ return cam
531
+
532
+
533
+ class BertModel(BertPreTrainedModel):
534
+ def __init__(self, config):
535
+ super().__init__(config)
536
+ self.config = config
537
+
538
+ self.embeddings = BertEmbeddings(config)
539
+ self.encoder = BertEncoder(config)
540
+ self.pooler = BertPooler(config)
541
+
542
+ self.init_weights()
543
+
544
+ def get_input_embeddings(self):
545
+ return self.embeddings.word_embeddings
546
+
547
+ def set_input_embeddings(self, value):
548
+ self.embeddings.word_embeddings = value
549
+
550
+ def forward(
551
+ self,
552
+ input_ids=None,
553
+ attention_mask=None,
554
+ token_type_ids=None,
555
+ position_ids=None,
556
+ head_mask=None,
557
+ inputs_embeds=None,
558
+ encoder_hidden_states=None,
559
+ encoder_attention_mask=None,
560
+ output_attentions=None,
561
+ output_hidden_states=None,
562
+ return_dict=None,
563
+ ):
564
+ r"""
565
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
566
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
567
+ if the model is configured as a decoder.
568
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
569
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
570
+ is used in the cross-attention if the model is configured as a decoder.
571
+ Mask values selected in ``[0, 1]``:
572
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
573
+ """
574
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
575
+ output_hidden_states = (
576
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
577
+ )
578
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
579
+
580
+ if input_ids is not None and inputs_embeds is not None:
581
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
582
+ elif input_ids is not None:
583
+ input_shape = input_ids.size()
584
+ elif inputs_embeds is not None:
585
+ input_shape = inputs_embeds.size()[:-1]
586
+ else:
587
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
588
+
589
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
590
+
591
+ if attention_mask is None:
592
+ attention_mask = torch.ones(input_shape, device=device)
593
+ if token_type_ids is None:
594
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
595
+
596
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
597
+ # ourselves in which case we just need to make it broadcastable to all heads.
598
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
599
+
600
+ # If a 2D or 3D attention mask is provided for the cross-attention
601
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
602
+ if self.config.is_decoder and encoder_hidden_states is not None:
603
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
604
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
605
+ if encoder_attention_mask is None:
606
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
607
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
608
+ else:
609
+ encoder_extended_attention_mask = None
610
+
611
+ # Prepare head mask if needed
612
+ # 1.0 in head_mask indicate we keep the head
613
+ # attention_probs has shape bsz x n_heads x N x N
614
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
615
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
616
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
617
+
618
+ embedding_output = self.embeddings(
619
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
620
+ )
621
+
622
+ encoder_outputs = self.encoder(
623
+ embedding_output,
624
+ attention_mask=extended_attention_mask,
625
+ head_mask=head_mask,
626
+ encoder_hidden_states=encoder_hidden_states,
627
+ encoder_attention_mask=encoder_extended_attention_mask,
628
+ output_attentions=output_attentions,
629
+ output_hidden_states=output_hidden_states,
630
+ return_dict=return_dict,
631
+ )
632
+ sequence_output = encoder_outputs[0]
633
+ pooled_output = self.pooler(sequence_output)
634
+
635
+ if not return_dict:
636
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
637
+
638
+ return BaseModelOutputWithPooling(
639
+ last_hidden_state=sequence_output,
640
+ pooler_output=pooled_output,
641
+ hidden_states=encoder_outputs.hidden_states,
642
+ attentions=encoder_outputs.attentions,
643
+ )
644
+
645
+ def relprop(self, cam, **kwargs):
646
+ cam = self.pooler.relprop(cam, **kwargs)
647
+ # print("111111111111",cam.sum())
648
+ cam = self.encoder.relprop(cam, **kwargs)
649
+ # print("222222222222222", cam.sum())
650
+ # print("conservation: ", cam.sum())
651
+ return cam
652
+
653
+
654
+ if __name__ == '__main__':
655
+ class Config:
656
+ def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
657
+ self.hidden_size = hidden_size
658
+ self.num_attention_heads = num_attention_heads
659
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
660
+
661
+ model = BertSelfAttention(Config(1024, 4, 0.1))
662
+ x = torch.rand(2, 20, 1024)
663
+ x.requires_grad_()
664
+
665
+ model.eval()
666
+
667
+ y = model(x)  # call the module (not .forward) so the relprop forward hooks record the inputs
668
+
669
+ relprop = model.relprop(torch.rand(2, 20, 1024), alpha=1)
670
+
671
+ print(relprop[1][0].shape)
BERT/BERT_explainability/modules/BERT/BertForSequenceClassification.py ADDED
@@ -0,0 +1,215 @@
1
+
2
+
3
+ from transformers import BertPreTrainedModel
4
+
5
+ from BERT_explainability.modules.layers_ours import *
6
+ from BERT_explainability.modules.BERT.BERT import BertModel
7
+ from torch.nn import CrossEntropyLoss, MSELoss
8
+ import torch.nn as nn
9
+ from typing import List, Any
10
+ import torch
11
+
12
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
13
+
14
+ import sys
15
+ sys.path.append("../../")
16
+ from sequenceoutput.modeling_output import SequenceClassifierOutput
17
+
18
+ class BertForSequenceClassification(BertPreTrainedModel):
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.num_labels = config.num_labels
22
+
23
+ self.bert = BertModel(config)
24
+ self.dropout = Dropout(config.hidden_dropout_prob)
25
+ self.classifier = Linear(config.hidden_size, config.num_labels)
26
+
27
+ self.init_weights()
28
+
29
+
30
+
31
+
32
+ def forward(
33
+ self,
34
+ input_ids=None,
35
+ attention_mask=None,
36
+ token_type_ids=None,
37
+ position_ids=None,
38
+ head_mask=None,
39
+ inputs_embeds=None,
40
+ labels=None,
41
+ output_attentions=None,
42
+ output_hidden_states=None,
43
+ return_dict=None,
44
+ ):
45
+ r"""
46
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
47
+ Labels for computing the sequence classification/regression loss.
48
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
49
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
50
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
51
+ """
52
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
53
+
54
+ outputs = self.bert(
55
+ input_ids,
56
+ attention_mask=attention_mask,
57
+ token_type_ids=token_type_ids,
58
+ position_ids=position_ids,
59
+ head_mask=head_mask,
60
+ inputs_embeds=inputs_embeds,
61
+ output_attentions=output_attentions,
62
+ output_hidden_states=output_hidden_states,
63
+ return_dict=return_dict,
64
+ )
65
+
66
+ pooled_output = outputs[1]
67
+
68
+ pooled_output = self.dropout(pooled_output)
69
+ logits = self.classifier(pooled_output)
70
+
71
+ loss = None
72
+ if labels is not None:
73
+ if self.num_labels == 1:
74
+ # We are doing regression
75
+ loss_fct = MSELoss()
76
+ loss = loss_fct(logits.view(-1), labels.view(-1))
77
+ else:
78
+ loss_fct = CrossEntropyLoss()
79
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
80
+
81
+ if not return_dict:
82
+ output = (logits,) + outputs[2:]
83
+ return ((loss,) + output) if loss is not None else output
84
+
85
+ return SequenceClassifierOutput(
86
+ loss=loss,
87
+ logits=logits,
88
+ hidden_states=outputs.hidden_states,
89
+ attentions=outputs.attentions,
90
+ )
91
+
92
+
93
+
94
+
95
+ def relprop(self, cam=None, **kwargs):
96
+ cam = self.classifier.relprop(cam, **kwargs)
97
+ cam = self.dropout.relprop(cam, **kwargs)
98
+ cam = self.bert.relprop(cam, **kwargs)
99
+ # print("conservation: ", cam.sum())
100
+ return cam
101
+
102
+
103
+ # this is the actual classifier we will be using
104
+ class BertClassifier(nn.Module):
105
+ """Thin wrapper around BertForSequenceClassification"""
106
+
107
+ def __init__(self,
108
+ bert_dir: str,
109
+ pad_token_id: int,
110
+ cls_token_id: int,
111
+ sep_token_id: int,
112
+ num_labels: int,
113
+ max_length: int = 512,
114
+ use_half_precision=True):
115
+ super(BertClassifier, self).__init__()
116
+ bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
117
+ if use_half_precision:
118
+ import apex
119
+ bert = bert.half()
120
+ self.bert = bert
121
+ self.pad_token_id = pad_token_id
122
+ self.cls_token_id = cls_token_id
123
+ self.sep_token_id = sep_token_id
124
+ self.max_length = max_length
125
+
126
+ def forward(self,
127
+ query: List[torch.tensor],
128
+ docids: List[Any],
129
+ document_batch: List[torch.tensor]):
130
+ assert len(query) == len(document_batch)
131
+ print(query)
132
+ # note about device management:
133
+ # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
134
+ # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
135
+ target_device = next(self.parameters()).device
136
+ cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
137
+ sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
138
+ input_tensors = []
139
+ position_ids = []
140
+ for q, d in zip(query, document_batch):
141
+ if len(q) + len(d) + 2 > self.max_length:
142
+ d = d[:(self.max_length - len(q) - 2)]
143
+ input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
144
+ position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
145
+ bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
146
+ device=target_device)
147
+ positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
148
+ (classes,) = self.bert(bert_input.data,
149
+ attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
150
+ position_ids=positions.data)
151
+ assert torch.all(classes == classes) # for nans
152
+
153
+ print(input_tensors[0])
154
+ print(self.relprop()[0])
155
+
156
+ return classes
157
+
158
+ def relprop(self, cam=None, **kwargs):
159
+ return self.bert.relprop(cam, **kwargs)
160
+
161
+
162
+ if __name__ == '__main__':
163
+ from transformers import BertTokenizer
164
+ import os
165
+
166
+ class Config:
167
+ def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
168
+ hidden_dropout_prob):
169
+ self.hidden_size = hidden_size
170
+ self.num_attention_heads = num_attention_heads
171
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
172
+ self.num_labels = num_labels
173
+ self.hidden_dropout_prob = hidden_dropout_prob
174
+
175
+
176
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
177
+ x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
178
+ add_special_tokens=True,
179
+ max_length=512,
180
+ return_token_type_ids=False,
181
+ return_attention_mask=True,
182
+ pad_to_max_length=True,
183
+ return_tensors='pt',
184
+ truncation=True)
185
+
186
+ print(x['input_ids'])
187
+
188
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
189
+ model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
190
+ model.load_state_dict(torch.load(model_save_file))
191
+
192
+ # x = torch.randint(100, (2, 20))
193
+ # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
194
+ # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
195
+ # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
196
+ # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
197
+ # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
198
+ # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
199
+ # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
200
+ # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
201
+ # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
202
+ # 102, 101, 1012, 102]])
203
+ # x.requires_grad_()
204
+
205
+ model.eval()
206
+
207
+ y = model(x['input_ids'], x['attention_mask'])
208
+ print(y)
209
+
210
+ cam = model.relprop(torch.tensor([[0.0, 1.0]]), alpha=1)  # seed relevance at class index 1 (POS for the movies classifier)
211
+
212
+ #print(cam.shape)
213
+
214
+ cam = cam.sum(-1)
215
+ #print(cam)
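The relprop chain above has to be seeded from the classifier logits. A minimal hand-rolled sketch of that flow (the Generator class in ExplanationGenerator.py below automates the same steps); `model`, `input_ids` and `attention_mask` are assumed to be prepared already, and alpha=1 matches the value used throughout this repo:

    import numpy as np
    import torch

    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    target = int(np.argmax(logits.cpu().data.numpy(), axis=-1))

    one_hot = torch.zeros(1, logits.size(-1), device=logits.device)
    one_hot[0, target] = 1.0
    model.zero_grad()
    torch.sum(one_hot * logits).backward(retain_graph=True)  # lets the attention hooks record gradients

    relevance = model.relprop(one_hot, alpha=1)  # (batch, seq_len, hidden): relevance at the embedding output
    token_relevance = relevance.sum(dim=-1)      # per-token scores, as in generate_full_lrp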
BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import glob
5
+
6
+ # compute rollout between attention layers
7
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
8
+ # adding residual consideration - code adapted from https://github.com/samiraabnar/attention_flow
9
+ num_tokens = all_layer_matrices[0].shape[1]
10
+ batch_size = all_layer_matrices[0].shape[0]
11
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
12
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
13
+ matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
14
+ for i in range(len(all_layer_matrices))]
15
+ joint_attention = matrices_aug[start_layer]
16
+ for i in range(start_layer+1, len(matrices_aug)):
17
+ joint_attention = matrices_aug[i].bmm(joint_attention)
18
+ return joint_attention
19
+
20
+ class Generator:
21
+ def __init__(self, model):
22
+ self.model = model
23
+ self.model.eval()
24
+
25
+ def forward(self, input_ids, attention_mask):
26
+ return self.model(input_ids, attention_mask)
27
+
28
+ def generate_LRP(self, input_ids, attention_mask,
29
+ index=None, start_layer=11):
30
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
31
+ kwargs = {"alpha": 1}
32
+
33
+ if index is None:
34
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
35
+
36
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
37
+ one_hot[0, index] = 1
38
+ one_hot_vector = one_hot
39
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
40
+ one_hot = torch.sum(one_hot.cuda() * output)
41
+
42
+ self.model.zero_grad()
43
+ one_hot.backward(retain_graph=True)
44
+
45
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
46
+
47
+ cams = []
48
+ blocks = self.model.bert.encoder.layer
49
+ for blk in blocks:
50
+ grad = blk.attention.self.get_attn_gradients()
51
+ cam = blk.attention.self.get_attn_cam()
52
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
53
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
54
+ cam = grad * cam
55
+ cam = cam.clamp(min=0).mean(dim=0)
56
+ cams.append(cam.unsqueeze(0))
57
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
58
+ rollout[:, 0, 0] = rollout[:, 0].min()
59
+ return rollout[:, 0]
60
+
61
+
62
+ def generate_LRP_last_layer(self, input_ids, attention_mask,
63
+ index=None):
64
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
65
+ kwargs = {"alpha": 1}
66
+ if index is None:
67
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
68
+
69
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
70
+ one_hot[0, index] = 1
71
+ one_hot_vector = one_hot
72
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
73
+ one_hot = torch.sum(one_hot.cuda() * output)
74
+
75
+ self.model.zero_grad()
76
+ one_hot.backward(retain_graph=True)
77
+
78
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
79
+
80
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0]
81
+ cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
82
+ cam[:, 0, 0] = 0
83
+ return cam[:, 0]
84
+
85
+ def generate_full_lrp(self, input_ids, attention_mask,
86
+ index=None):
87
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
88
+ kwargs = {"alpha": 1}
89
+
90
+ if index is None:
91
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
92
+
93
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
94
+ one_hot[0, index] = 1
95
+ one_hot_vector = one_hot
96
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
97
+ one_hot = torch.sum(one_hot.cuda() * output)
98
+
99
+ self.model.zero_grad()
100
+ one_hot.backward(retain_graph=True)
101
+
102
+ cam = self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
103
+ cam = cam.sum(dim=2)
104
+ cam[:, 0] = 0
105
+ return cam
106
+
107
+ def generate_attn_last_layer(self, input_ids, attention_mask,
108
+ index=None):
109
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
110
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()[0]
111
+ cam = cam.mean(dim=0).unsqueeze(0)
112
+ cam[:, 0, 0] = 0
113
+ return cam[:, 0]
114
+
115
+ def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
116
+ self.model.zero_grad()
117
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
118
+ blocks = self.model.bert.encoder.layer
119
+ all_layer_attentions = []
120
+ for blk in blocks:
121
+ attn_heads = blk.attention.self.get_attn()
122
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
123
+ all_layer_attentions.append(avg_heads)
124
+ rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
125
+ rollout[:, 0, 0] = 0
126
+ return rollout[:, 0]
127
+
128
+ def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
129
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
130
+ kwargs = {"alpha": 1}
131
+
132
+ if index is None:
133
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
134
+
135
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
136
+ one_hot[0, index] = 1
137
+ one_hot_vector = one_hot
138
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
139
+ one_hot = torch.sum(one_hot.cuda() * output)
140
+
141
+ self.model.zero_grad()
142
+ one_hot.backward(retain_graph=True)
143
+
144
+ self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
145
+
146
+ cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()
147
+ grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients()
148
+
149
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
150
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
151
+ grad = grad.mean(dim=[1, 2], keepdim=True)
152
+ cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
153
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
154
+ cam[:, 0, 0] = 0
155
+ return cam[:, 0]
156
+
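An end-to-end usage sketch for the Generator above. The checkpoint path is a placeholder and a CUDA device is assumed (generate_LRP moves the one-hot seed to .cuda() internally); any fine-tuned BERT sequence-classification checkpoint compatible with this BertForSequenceClassification should work:

    import torch
    from transformers import BertTokenizer
    from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
    from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("path/to/finetuned-bert").cuda()  # placeholder path
    generator = Generator(model)  # also puts the model in eval mode

    enc = tokenizer("The acting was great but the plot went nowhere.", return_tensors="pt")
    input_ids = enc["input_ids"].cuda()
    attention_mask = enc["attention_mask"].cuda()

    scores = generator.generate_LRP(input_ids, attention_mask, start_layer=11)[0]
    scores = (scores - scores.min()) / (scores.max() - scores.min())  # normalize for display
    for token, score in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), scores):
        print(f"{token}\t{score.item():.3f}")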
BERT/BERT_explainability/modules/BERT/__pycache__/BERT.cpython-38.pyc ADDED
Binary file (18.2 kB).
 
BERT/BERT_explainability/modules/BERT/__pycache__/BertForSequenceClassification.cpython-38.pyc ADDED
Binary file (5.76 kB).
 
BERT/BERT_explainability/modules/BERT/__pycache__/ExplanationGenerator.cpython-311.pyc ADDED
Binary file (12.9 kB).
 
BERT/BERT_explainability/modules/BERT/__pycache__/ExplanationGenerator.cpython-38.pyc ADDED
Binary file (5.5 kB).
 
BERT/BERT_explainability/modules/__init__.py ADDED
File without changes
BERT/BERT_explainability/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (181 Bytes).
 
BERT/BERT_explainability/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes).
 
BERT/BERT_explainability/modules/__pycache__/layers_ours.cpython-38.pyc ADDED
Binary file (10.6 kB).
 
BERT/BERT_explainability/modules/layers_lrp.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+
49
+ class RelPropSimple(RelProp):
50
+ def relprop(self, R, alpha):
51
+ Z = self.forward(self.X)
52
+ S = safe_divide(R, Z)
53
+ C = self.gradprop(Z, self.X, S)
54
+
55
+ if torch.is_tensor(self.X) == False:
56
+ outputs = []
57
+ outputs.append(self.X[0] * C[0])
58
+ outputs.append(self.X[1] * C[1])
59
+ else:
60
+ outputs = self.X * (C[0])
61
+ return outputs
62
+
63
+ class AddEye(RelPropSimple):
64
+ # input of shape B, C, seq_len, seq_len
65
+ def forward(self, input):
66
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67
+
68
+ class ReLU(nn.ReLU, RelProp):
69
+ pass
70
+
71
+ class Tanh(nn.Tanh, RelProp):
72
+ pass
73
+
74
+ class GELU(nn.GELU, RelProp):
75
+ pass
76
+
77
+ class Softmax(nn.Softmax, RelProp):
78
+ pass
79
+
80
+ class LayerNorm(nn.LayerNorm, RelProp):
81
+ pass
82
+
83
+ class Dropout(nn.Dropout, RelProp):
84
+ pass
85
+
86
+
87
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
88
+ pass
89
+
90
+ class LayerNorm(nn.LayerNorm, RelProp):
91
+ pass
92
+
93
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
94
+ pass
95
+
96
+ class MatMul(RelPropSimple):
97
+ def forward(self, inputs):
98
+ return torch.matmul(*inputs)
99
+
100
+ class Mul(RelPropSimple):
101
+ def forward(self, inputs):
102
+ return torch.mul(*inputs)
103
+
104
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
105
+ pass
106
+
107
+
108
+ class Add(RelPropSimple):
109
+ def forward(self, inputs):
110
+ return torch.add(*inputs)
111
+
112
+ class einsum(RelPropSimple):
113
+ def __init__(self, equation):
114
+ super().__init__()
115
+ self.equation = equation
116
+ def forward(self, *operands):
117
+ return torch.einsum(self.equation, *operands)
118
+
119
+ class IndexSelect(RelProp):
120
+ def forward(self, inputs, dim, indices):
121
+ self.__setattr__('dim', dim)
122
+ self.__setattr__('indices', indices)
123
+
124
+ return torch.index_select(inputs, dim, indices)
125
+
126
+ def relprop(self, R, alpha):
127
+ Z = self.forward(self.X, self.dim, self.indices)
128
+ S = safe_divide(R, Z)
129
+ C = self.gradprop(Z, self.X, S)
130
+
131
+ if torch.is_tensor(self.X) == False:
132
+ outputs = []
133
+ outputs.append(self.X[0] * C[0])
134
+ outputs.append(self.X[1] * C[1])
135
+ else:
136
+ outputs = self.X * (C[0])
137
+ return outputs
138
+
139
+
140
+
141
+ class Clone(RelProp):
142
+ def forward(self, input, num):
143
+ self.__setattr__('num', num)
144
+ outputs = []
145
+ for _ in range(num):
146
+ outputs.append(input)
147
+
148
+ return outputs
149
+
150
+ def relprop(self, R, alpha):
151
+ Z = []
152
+ for _ in range(self.num):
153
+ Z.append(self.X)
154
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
155
+ C = self.gradprop(Z, self.X, S)[0]
156
+
157
+ R = self.X * C
158
+
159
+ return R
160
+
161
+ class Cat(RelProp):
162
+ def forward(self, inputs, dim):
163
+ self.__setattr__('dim', dim)
164
+ return torch.cat(inputs, dim)
165
+
166
+ def relprop(self, R, alpha):
167
+ Z = self.forward(self.X, self.dim)
168
+ S = safe_divide(R, Z)
169
+ C = self.gradprop(Z, self.X, S)
170
+
171
+ outputs = []
172
+ for x, c in zip(self.X, C):
173
+ outputs.append(x * c)
174
+
175
+ return outputs
176
+
177
+ class Sequential(nn.Sequential):
178
+ def relprop(self, R, alpha):
179
+ for m in reversed(self._modules.values()):
180
+ R = m.relprop(R, alpha)
181
+ return R
182
+
183
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
184
+ def relprop(self, R, alpha):
185
+ X = self.X
186
+ beta = 1 - alpha
187
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
188
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
189
+ Z = X * weight + 1e-9
190
+ S = R / Z
191
+ Ca = S * weight
192
+ R = self.X * (Ca)
193
+ return R
194
+
195
+
196
+ class Linear(nn.Linear, RelProp):
197
+ def relprop(self, R, alpha):
198
+ beta = alpha - 1
199
+ pw = torch.clamp(self.weight, min=0)
200
+ nw = torch.clamp(self.weight, max=0)
201
+ px = torch.clamp(self.X, min=0)
202
+ nx = torch.clamp(self.X, max=0)
203
+
204
+ def f(w1, w2, x1, x2):
205
+ Z1 = F.linear(x1, w1)
206
+ Z2 = F.linear(x2, w2)
207
+ S1 = safe_divide(R, Z1)
208
+ S2 = safe_divide(R, Z2)
209
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
210
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
211
+
212
+ return C1 + C2
213
+
214
+ activator_relevances = f(pw, nw, px, nx)
215
+ inhibitor_relevances = f(nw, pw, px, nx)
216
+
217
+ R = alpha * activator_relevances - beta * inhibitor_relevances
218
+
219
+ return R
220
+
221
+ class Conv2d(nn.Conv2d, RelProp):
222
+ def gradprop2(self, DY, weight):
223
+ Z = self.forward(self.X)
224
+
225
+ output_padding = self.X.size()[2] - (
226
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
227
+
228
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
229
+
230
+ def relprop(self, R, alpha):
231
+ if self.X.shape[1] == 3:
232
+ pw = torch.clamp(self.weight, min=0)
233
+ nw = torch.clamp(self.weight, max=0)
234
+ X = self.X
235
+ L = self.X * 0 + \
236
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
237
+ keepdim=True)[0]
238
+ H = self.X * 0 + \
239
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
240
+ keepdim=True)[0]
241
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
242
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
243
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
244
+
245
+ S = R / Za
246
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
247
+ R = C
248
+ else:
249
+ beta = alpha - 1
250
+ pw = torch.clamp(self.weight, min=0)
251
+ nw = torch.clamp(self.weight, max=0)
252
+ px = torch.clamp(self.X, min=0)
253
+ nx = torch.clamp(self.X, max=0)
254
+
255
+ def f(w1, w2, x1, x2):
256
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
257
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
258
+ S1 = safe_divide(R, Z1)
259
+ S2 = safe_divide(R, Z2)
260
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
261
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
262
+ return C1 + C2
263
+
264
+ activator_relevances = f(pw, nw, px, nx)
265
+ inhibitor_relevances = f(nw, pw, px, nx)
266
+
267
+ R = alpha * activator_relevances - beta * inhibitor_relevances
268
+ return R
BERT/BERT_explainability/modules/layers_ours.py ADDED
@@ -0,0 +1,292 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+
49
+ class RelPropSimple(RelProp):
50
+ def relprop(self, R, alpha):
51
+ Z = self.forward(self.X)
52
+ S = safe_divide(R, Z)
53
+ C = self.gradprop(Z, self.X, S)
54
+
55
+ if torch.is_tensor(self.X) == False:
56
+ outputs = []
57
+ outputs.append(self.X[0] * C[0])
58
+ outputs.append(self.X[1] * C[1])
59
+ else:
60
+ outputs = self.X * (C[0])
61
+ return outputs
62
+
63
+ class AddEye(RelPropSimple):
64
+ # input of shape B, C, seq_len, seq_len
65
+ def forward(self, input):
66
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67
+
68
+ class ReLU(nn.ReLU, RelProp):
69
+ pass
70
+
71
+ class GELU(nn.GELU, RelProp):
72
+ pass
73
+
74
+ class Softmax(nn.Softmax, RelProp):
75
+ pass
76
+
77
+ class Mul(RelPropSimple):
78
+ def forward(self, inputs):
79
+ return torch.mul(*inputs)
80
+
81
+ class Tanh(nn.Tanh, RelProp):
82
+ pass
83
+ class LayerNorm(nn.LayerNorm, RelProp):
84
+ pass
85
+
86
+ class Dropout(nn.Dropout, RelProp):
87
+ pass
88
+
89
+ class MatMul(RelPropSimple):
90
+ def forward(self, inputs):
91
+ return torch.matmul(*inputs)
92
+
93
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
94
+ pass
95
+
96
+ class LayerNorm(nn.LayerNorm, RelProp):
97
+ pass
98
+
99
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
100
+ pass
101
+
102
+
103
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
104
+ pass
105
+
106
+
107
+ class Add(RelPropSimple):
108
+ def forward(self, inputs):
109
+ return torch.add(*inputs)
110
+
111
+ def relprop(self, R, alpha):
112
+ Z = self.forward(self.X)
113
+ S = safe_divide(R, Z)
114
+ C = self.gradprop(Z, self.X, S)
115
+
116
+ a = self.X[0] * C[0]
117
+ b = self.X[1] * C[1]
118
+
119
+ a_sum = a.sum()
120
+ b_sum = b.sum()
121
+
122
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
123
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
124
+
125
+ a = a * safe_divide(a_fact, a.sum())
126
+ b = b * safe_divide(b_fact, b.sum())
127
+
128
+ outputs = [a, b]
129
+
130
+ return outputs
131
+
132
+ class einsum(RelPropSimple):
133
+ def __init__(self, equation):
134
+ super().__init__()
135
+ self.equation = equation
136
+ def forward(self, *operands):
137
+ return torch.einsum(self.equation, *operands)
138
+
139
+ class IndexSelect(RelProp):
140
+ def forward(self, inputs, dim, indices):
141
+ self.__setattr__('dim', dim)
142
+ self.__setattr__('indices', indices)
143
+
144
+ return torch.index_select(inputs, dim, indices)
145
+
146
+ def relprop(self, R, alpha):
147
+ Z = self.forward(self.X, self.dim, self.indices)
148
+ S = safe_divide(R, Z)
149
+ C = self.gradprop(Z, self.X, S)
150
+
151
+ if torch.is_tensor(self.X) == False:
152
+ outputs = []
153
+ outputs.append(self.X[0] * C[0])
154
+ outputs.append(self.X[1] * C[1])
155
+ else:
156
+ outputs = self.X * (C[0])
157
+ return outputs
158
+
159
+
160
+
161
+ class Clone(RelProp):
162
+ def forward(self, input, num):
163
+ self.__setattr__('num', num)
164
+ outputs = []
165
+ for _ in range(num):
166
+ outputs.append(input)
167
+
168
+ return outputs
169
+
170
+ def relprop(self, R, alpha):
171
+ Z = []
172
+ for _ in range(self.num):
173
+ Z.append(self.X)
174
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
175
+ C = self.gradprop(Z, self.X, S)[0]
176
+
177
+ R = self.X * C
178
+
179
+ return R
180
+
181
+
182
+ class Cat(RelProp):
183
+ def forward(self, inputs, dim):
184
+ self.__setattr__('dim', dim)
185
+ return torch.cat(inputs, dim)
186
+
187
+ def relprop(self, R, alpha):
188
+ Z = self.forward(self.X, self.dim)
189
+ S = safe_divide(R, Z)
190
+ C = self.gradprop(Z, self.X, S)
191
+
192
+ outputs = []
193
+ for x, c in zip(self.X, C):
194
+ outputs.append(x * c)
195
+
196
+ return outputs
197
+
198
+
199
+ class Sequential(nn.Sequential):
200
+ def relprop(self, R, alpha):
201
+ for m in reversed(self._modules.values()):
202
+ R = m.relprop(R, alpha)
203
+ return R
204
+
205
+
206
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
207
+ def relprop(self, R, alpha):
208
+ X = self.X
209
+ beta = 1 - alpha
210
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
211
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
212
+ Z = X * weight + 1e-9
213
+ S = R / Z
214
+ Ca = S * weight
215
+ R = self.X * (Ca)
216
+ return R
217
+
218
+
219
+ class Linear(nn.Linear, RelProp):
220
+ def relprop(self, R, alpha):
221
+ beta = alpha - 1
222
+ pw = torch.clamp(self.weight, min=0)
223
+ nw = torch.clamp(self.weight, max=0)
224
+ px = torch.clamp(self.X, min=0)
225
+ nx = torch.clamp(self.X, max=0)
226
+
227
+ def f(w1, w2, x1, x2):
228
+ Z1 = F.linear(x1, w1)
229
+ Z2 = F.linear(x2, w2)
230
+ S1 = safe_divide(R, Z1 + Z2)
231
+ S2 = safe_divide(R, Z1 + Z2)
232
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
233
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
234
+
235
+ return C1 + C2
236
+
237
+ activator_relevances = f(pw, nw, px, nx)
238
+ inhibitor_relevances = f(nw, pw, px, nx)
239
+
240
+ R = alpha * activator_relevances - beta * inhibitor_relevances
241
+
242
+ return R
243
+
244
+
245
+ class Conv2d(nn.Conv2d, RelProp):
246
+ def gradprop2(self, DY, weight):
247
+ Z = self.forward(self.X)
248
+
249
+ output_padding = self.X.size()[2] - (
250
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
251
+
252
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
253
+
254
+ def relprop(self, R, alpha):
255
+ if self.X.shape[1] == 3:
256
+ pw = torch.clamp(self.weight, min=0)
257
+ nw = torch.clamp(self.weight, max=0)
258
+ X = self.X
259
+ L = self.X * 0 + \
260
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
261
+ keepdim=True)[0]
262
+ H = self.X * 0 + \
263
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
264
+ keepdim=True)[0]
265
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
266
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
267
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
268
+
269
+ S = R / Za
270
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
271
+ R = C
272
+ else:
273
+ beta = alpha - 1
274
+ pw = torch.clamp(self.weight, min=0)
275
+ nw = torch.clamp(self.weight, max=0)
276
+ px = torch.clamp(self.X, min=0)
277
+ nx = torch.clamp(self.X, max=0)
278
+
279
+ def f(w1, w2, x1, x2):
280
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
281
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
282
+ S1 = safe_divide(R, Z1)
283
+ S2 = safe_divide(R, Z2)
284
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
285
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
286
+ return C1 + C2
287
+
288
+ activator_relevances = f(pw, nw, px, nx)
289
+ inhibitor_relevances = f(nw, pw, px, nx)
290
+
291
+ R = alpha * activator_relevances - beta * inhibitor_relevances
292
+ return R
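A tiny sketch of how these drop-in layers are exercised: call the module normally so the forward hook records its input, then call relprop with an alpha value (alpha=1 gives the alpha1-beta0 rule implemented by Linear above); the shapes here are arbitrary:

    import torch
    from BERT_explainability.modules.layers_ours import Linear

    layer = Linear(4, 3)
    x = torch.randn(2, 4)
    y = layer(x)                      # the forward hook stores layer.X for relprop

    R_out = torch.softmax(y, dim=-1)  # any relevance tensor with y's shape
    R_in = layer.relprop(R_out, alpha=1)
    print(R_in.shape)                 # torch.Size([2, 4])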
BERT/BERT_params/boolq.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.2,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "False", "True" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.2,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
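The BERT_params files are plain hyperparameter bundles consumed by the rationale-benchmark training scripts (not shown in this excerpt). A minimal sketch of reading one, assuming the repository root as the working directory:

    import json

    with open("BERT/BERT_params/boolq.json") as f:
        params = json.load(f)

    print(params["evidence_classifier"]["classes"])     # ['False', 'True']
    print(params["evidence_identifier"]["batch_size"])  # 768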
BERT/BERT_params/boolq_baas.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "False", "True" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
26
+
BERT/BERT_params/boolq_bert.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 10,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "False",
21
+ "True"
22
+ ],
23
+ "batch_size": 10,
24
+ "warmup_steps": 50,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
BERT/BERT_params/boolq_soft.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "False", "True" ],
8
+ "has_query": 1,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 16,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
BERT/BERT_params/cose_bert.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 0,
6
+ "use_evidence_token_identifier": 1,
7
+ "evidence_token_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 0.5,
14
+ "sampling_method": "everything",
15
+ "use_half_precision": 0,
16
+ "cose_data_hack": 1
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [ "false", "true"],
20
+ "batch_size": 32,
21
+ "warmup_steps": 10,
22
+ "epochs": 10,
23
+ "patience": 10,
24
+ "lr": 1e-05,
25
+ "max_grad_norm": 0.5,
26
+ "sampling_method": "everything",
27
+ "use_half_precision": 0,
28
+ "cose_data_hack": 1
29
+ }
30
+ }
BERT/BERT_params/cose_multiclass.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 50,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "A",
21
+ "B",
22
+ "C",
23
+ "D",
24
+ "E"
25
+ ],
26
+ "batch_size": 10,
27
+ "warmup_steps": 50,
28
+ "epochs": 10,
29
+ "patience": 10,
30
+ "lr": 1e-05,
31
+ "max_grad_norm": 1,
32
+ "sampling_method": "everything",
33
+ "use_half_precision": 0
34
+ }
35
+ }
BERT/BERT_params/esnli_bert.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 0,
6
+ "use_evidence_token_identifier": 1,
7
+ "evidence_token_identifier": {
8
+ "batch_size": 32,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "everything",
15
+ "use_half_precision": 0
16
+ },
17
+ "evidence_classifier": {
18
+ "classes": [ "contradiction", "neutral", "entailment" ],
19
+ "batch_size": 32,
20
+ "warmup_steps": 10,
21
+ "epochs": 10,
22
+ "patience": 10,
23
+ "lr": 1e-05,
24
+ "max_grad_norm": 1,
25
+ "sampling_method": "everything",
26
+ "use_half_precision": 0
27
+ }
28
+ }
BERT/BERT_params/evidence_inference.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/PubMed-w2v.bin",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
BERT/BERT_params/evidence_inference_bert.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "allenai/scibert_scivocab_uncased",
4
+ "bert_dir": "allenai/scibert_scivocab_uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 10,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1,
14
+ "sampling_method": "random",
15
+ "use_half_precision": 0,
16
+ "sampling_ratio": 1
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "significantly decreased",
21
+ "no significant difference",
22
+ "significantly increased"
23
+ ],
24
+ "batch_size": 10,
25
+ "warmup_steps": 10,
26
+ "epochs": 10,
27
+ "patience": 10,
28
+ "lr": 1e-05,
29
+ "max_grad_norm": 1,
30
+ "sampling_method": "everything",
31
+ "use_half_precision": 0
32
+ }
33
+ }
BERT/BERT_params/evidence_inference_soft.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/PubMed-w2v.bin",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
8
+ "use_token_selection": 1,
9
+ "has_query": 1,
10
+ "hidden_size": 32,
11
+ "mlp_size": 128,
12
+ "dropout": 0.2,
13
+ "batch_size": 16,
14
+ "epochs": 50,
15
+ "attention_epochs": 0,
16
+ "patience": 10,
17
+ "lr": 1e-3,
18
+ "dropout": 0.2,
19
+ "k_fraction": 0.013,
20
+ "threshold": 0.1
21
+ }
22
+ }
BERT/BERT_params/fever.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-3,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "SUPPORTS", "REFUTES" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-5,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
BERT/BERT_params/fever_baas.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "start_server": 0,
3
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4
+ "max_length": 512,
5
+ "pooling_strategy": "CLS_TOKEN",
6
+ "evidence_identifier": {
7
+ "batch_size": 64,
8
+ "epochs": 3,
9
+ "patience": 10,
10
+ "lr": 1e-3,
11
+ "max_grad_norm": 1.0,
12
+ "sampling_method": "random",
13
+ "sampling_ratio": 1.0
14
+ },
15
+ "evidence_classifier": {
16
+ "classes": [ "SUPPORTS", "REFUTES" ],
17
+ "batch_size": 64,
18
+ "epochs": 3,
19
+ "patience": 10,
20
+ "lr": 1e-3,
21
+ "max_grad_norm": 1.0,
22
+ "sampling_method": "everything"
23
+ }
24
+ }
25
+
BERT/BERT_params/fever_bert.json ADDED
1
+ {
2
+ "max_length": 512,
3
+ "bert_vocab": "bert-base-uncased",
4
+ "bert_dir": "bert-base-uncased",
5
+ "use_evidence_sentence_identifier": 1,
6
+ "use_evidence_token_identifier": 0,
7
+ "evidence_identifier": {
8
+ "batch_size": 16,
9
+ "epochs": 10,
10
+ "patience": 10,
11
+ "warmup_steps": 10,
12
+ "lr": 1e-05,
13
+ "max_grad_norm": 1.0,
14
+ "sampling_method": "random",
15
+ "sampling_ratio": 1.0,
16
+ "use_half_precision": 0
17
+ },
18
+ "evidence_classifier": {
19
+ "classes": [
20
+ "SUPPORTS",
21
+ "REFUTES"
22
+ ],
23
+ "batch_size": 10,
24
+ "warmup_steps": 10,
25
+ "epochs": 10,
26
+ "patience": 10,
27
+ "lr": 1e-05,
28
+ "max_grad_norm": 1.0,
29
+ "sampling_method": "everything",
30
+ "use_half_precision": 0
31
+ }
32
+ }
BERT/BERT_params/fever_soft.json ADDED
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.2
5
+ },
6
+ "classifier": {
7
+ "classes": [ "SUPPORTS", "REFUTES" ],
8
+ "has_query": 1,
9
+ "hidden_size": 32,
10
+ "mlp_size": 128,
11
+ "dropout": 0.2,
12
+ "batch_size": 128,
13
+ "epochs": 50,
14
+ "attention_epochs": 50,
15
+ "patience": 10,
16
+ "lr": 1e-3,
17
+ "dropout": 0.2,
18
+ "k_fraction": 0.07,
19
+ "threshold": 0.1
20
+ }
21
+ }
BERT/BERT_params/movies.json ADDED
1
+ {
2
+ "embeddings": {
3
+ "embedding_file": "model_components/glove.6B.200d.txt",
4
+ "dropout": 0.05
5
+ },
6
+ "evidence_identifier": {
7
+ "mlp_size": 128,
8
+ "dropout": 0.05,
9
+ "batch_size": 768,
10
+ "epochs": 50,
11
+ "patience": 10,
12
+ "lr": 1e-4,
13
+ "sampling_method": "random",
14
+ "sampling_ratio": 1.0
15
+ },
16
+ "evidence_classifier": {
17
+ "classes": [ "NEG", "POS" ],
18
+ "mlp_size": 128,
19
+ "dropout": 0.05,
20
+ "batch_size": 768,
21
+ "epochs": 50,
22
+ "patience": 10,
23
+ "lr": 1e-3,
24
+ "sampling_method": "everything"
25
+ }
26
+ }
BERT/BERT_params/movies_baas.json ADDED
@@ -0,0 +1,26 @@
+ {
+ "start_server": 0,
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
+ "max_length": 512,
+ "pooling_strategy": "CLS_TOKEN",
+ "evidence_identifier": {
+ "batch_size": 64,
+ "epochs": 3,
+ "patience": 10,
+ "lr": 1e-3,
+ "max_grad_norm": 1.0,
+ "sampling_method": "random",
+ "sampling_ratio": 1.0
+ },
+ "evidence_classifier": {
+ "classes": [ "NEG", "POS" ],
+ "batch_size": 64,
+ "epochs": 3,
+ "patience": 10,
+ "lr": 1e-3,
+ "max_grad_norm": 1.0,
+ "sampling_method": "everything"
+ }
+ }
+
+
BERT/BERT_params/movies_bert.json ADDED
@@ -0,0 +1,32 @@
+ {
+ "max_length": 512,
+ "bert_vocab": "bert-base-uncased",
+ "bert_dir": "bert-base-uncased",
+ "use_evidence_sentence_identifier": 1,
+ "use_evidence_token_identifier": 0,
+ "evidence_identifier": {
+ "batch_size": 16,
+ "epochs": 10,
+ "patience": 10,
+ "warmup_steps": 50,
+ "lr": 1e-05,
+ "max_grad_norm": 1,
+ "sampling_method": "random",
+ "sampling_ratio": 1,
+ "use_half_precision": 0
+ },
+ "evidence_classifier": {
+ "classes": [
+ "NEG",
+ "POS"
+ ],
+ "batch_size": 10,
+ "warmup_steps": 50,
+ "epochs": 10,
+ "patience": 10,
+ "lr": 1e-05,
+ "max_grad_norm": 1,
+ "sampling_method": "everything",
+ "use_half_precision": 0
+ }
+ }
BERT/BERT_params/movies_soft.json ADDED
@@ -0,0 +1,21 @@
+ {
+ "embeddings": {
+ "embedding_file": "model_components/glove.6B.200d.txt",
+ "dropout": 0.2
+ },
+ "classifier": {
+ "classes": [ "NEG", "POS" ],
+ "has_query": 0,
+ "hidden_size": 32,
+ "mlp_size": 128,
+ "dropout": 0.2,
+ "batch_size": 16,
+ "epochs": 50,
+ "attention_epochs": 50,
+ "patience": 10,
+ "lr": 1e-3,
+ "dropout": 0.2,
+ "k_fraction": 0.07,
+ "threshold": 0.1
+ }
+ }
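Across the soft configs, has_query flips between datasets: movies is plain document classification (has_query: 0), while FEVER, MultiRC, and evidence inference pair each document with a query or claim (has_query: 1). A small hedged sketch of what that flag plausibly controls when inputs are built (the real input construction lives in the benchmark code; build_input below is purely illustrative):

    def build_input(document_tokens, query_tokens=None, has_query=0):
        """Illustrative: prepend the query to the document only when the config asks for one."""
        if has_query and query_tokens:
            return query_tokens + ["[SEP]"] + document_tokens
        return list(document_tokens)

    print(build_input(["a", "gripping", "film"], has_query=0))
    print(build_input(["a", "gripping", "film"],
                      query_tokens=["is", "the", "review", "positive", "?"],
                      has_query=1))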
BERT/BERT_params/multirc.json ADDED
@@ -0,0 +1,26 @@
+ {
+ "embeddings": {
+ "embedding_file": "model_components/glove.6B.200d.txt",
+ "dropout": 0.05
+ },
+ "evidence_identifier": {
+ "mlp_size": 128,
+ "dropout": 0.05,
+ "batch_size": 768,
+ "epochs": 50,
+ "patience": 10,
+ "lr": 1e-3,
+ "sampling_method": "random",
+ "sampling_ratio": 1.0
+ },
+ "evidence_classifier": {
+ "classes": [ "False", "True" ],
+ "mlp_size": 128,
+ "dropout": 0.05,
+ "batch_size": 768,
+ "epochs": 50,
+ "patience": 10,
+ "lr": 1e-3,
+ "sampling_method": "everything"
+ }
+ }
BERT/BERT_params/multirc_baas.json ADDED
@@ -0,0 +1,26 @@
+ {
+ "start_server": 0,
+ "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
+ "max_length": 512,
+ "pooling_strategy": "CLS_TOKEN",
+ "evidence_identifier": {
+ "batch_size": 64,
+ "epochs": 3,
+ "patience": 10,
+ "lr": 1e-3,
+ "max_grad_norm": 1.0,
+ "sampling_method": "random",
+ "sampling_ratio": 1.0
+ },
+ "evidence_classifier": {
+ "classes": [ "False", "True" ],
+ "batch_size": 64,
+ "epochs": 3,
+ "patience": 10,
+ "lr": 1e-3,
+ "max_grad_norm": 1.0,
+ "sampling_method": "everything"
+ }
+ }
+
+
BERT/BERT_params/multirc_bert.json ADDED
@@ -0,0 +1,32 @@
+ {
+ "max_length": 512,
+ "bert_vocab": "bert-base-uncased",
+ "bert_dir": "bert-base-uncased",
+ "use_evidence_sentence_identifier": 1,
+ "use_evidence_token_identifier": 0,
+ "evidence_identifier": {
+ "batch_size": 32,
+ "epochs": 10,
+ "patience": 10,
+ "warmup_steps": 50,
+ "lr": 1e-05,
+ "max_grad_norm": 1,
+ "sampling_method": "random",
+ "sampling_ratio": 1,
+ "use_half_precision": 0
+ },
+ "evidence_classifier": {
+ "classes": [
+ "False",
+ "True"
+ ],
+ "batch_size": 32,
+ "warmup_steps": 50,
+ "epochs": 10,
+ "patience": 10,
+ "lr": 1e-05,
+ "max_grad_norm": 1,
+ "sampling_method": "everything",
+ "use_half_precision": 0
+ }
+ }
BERT/BERT_params/multirc_soft.json ADDED
@@ -0,0 +1,21 @@
+ {
+ "embeddings": {
+ "embedding_file": "model_components/glove.6B.200d.txt",
+ "dropout": 0.2
+ },
+ "classifier": {
+ "classes": [ "False", "True" ],
+ "has_query": 1,
+ "hidden_size": 32,
+ "mlp_size": 128,
+ "dropout": 0.2,
+ "batch_size": 16,
+ "epochs": 50,
+ "attention_epochs": 50,
+ "patience": 10,
+ "lr": 1e-3,
+ "dropout": 0.2,
+ "k_fraction": 0.07,
+ "threshold": 0.1
+ }
+ }
BERT/BERT_rationale_benchmark/__init__.py ADDED
File without changes