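# Token-level LRP explanations for a BERT sentiment classifier:
# generate per-token relevance with the BERT_explainability Generator
# and render it with Captum's text visualization utilities.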
import torch
from captum.attr import visualization
from transformers import AutoTokenizer

from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

# load the fine-tuned model and tokenizer (expects a local ./BERT checkpoint)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BertForSequenceClassification.from_pretrained("./BERT").to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("./BERT")
# initialize the explanations generator
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]

# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# the true class is POSITIVE (label 1)
true_class = 1

# generate an explanation for the NEGATIVE class (label 0);
# start_layer=11 starts relevance propagation at the last encoder layer of BERT-base
target_class = 0
expl = explanations.generate_LRP(
    input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class
)[0]
# min-max normalize the relevance scores to [0, 1]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model prediction as class probabilities
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# name of the class the explanation was generated for
class_name = classifications[target_class]
# when explaining the NEGATIVE class, flip the sign so the scores
# appear as negative attributions in the visualization
if class_name == "NEGATIVE":
    expl *= (-1)
# map each token to its relevance score
# (note: repeated tokens overwrite earlier entries, since the dict is keyed by token string)
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
token_importance = {token: expl[i].item() for i, token in enumerate(tokens)}
vis_data_records = [visualization.VisualizationDataRecord(
    expl,                        # word attributions
    output[0][classification],   # predicted class probability
    classification,              # predicted class
    true_class,                  # true class
    target_class,                # class the attributions were computed for
    expl.sum(),                  # overall attribution score
    tokens,                      # raw input tokens
    1)]                          # convergence score (unused here)

# render the records (displays inline in a notebook; recent Captum
# versions also return the underlying IPython HTML object)
html1 = visualization.visualize_text(vis_data_records)
# to persist the visualization, write the HTML markup itself
# (html1.data), not str(html1):
# with open('bert-xai.html', 'w') as f:
#     f.write(html1.data)
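# A minimal sketch (not part of the original pipeline) for inspecting the
# scores directly: list tokens from most to least relevant.
for token, score in sorted(token_importance.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{token}\t{score:.3f}")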