File size: 1,910 Bytes
090a94e
7e60a64
 
 
090a94e
 
 
 
 
d338d4f
090a94e
 
 
 
 
8eabb3c
090a94e
8eabb3c
090a94e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a42a4
090a94e
 
 
 
 
 
 
 
 
8797a1a
 
090a94e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import sys

sys.path.append('BERT')
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import AutoTokenizer
from captum.attr import visualization

import torch


from sequenceoutput.modeling_output import SequenceClassifierOutput

# Load the fine-tuned sentiment classifier and its matching tokenizer from the
# local weights directory (expects ./BERT/BERT_weight to exist on disk).
model = BertForSequenceClassification.from_pretrained("./BERT/BERT_weight")
model.eval()  # inference mode: disables dropout / batch-norm updates
tokenizer = AutoTokenizer.from_pretrained("./BERT/BERT_weight")
# initialize the explanations generator
explanations = Generator(model)

# Label names indexed by class id: 0 -> NEGATIVE, 1 -> POSITIVE.
classifications = ["NEGATIVE", "POSITIVE"]
# Ground-truth label fed to the visualization record below.
# NOTE(review): hard-coded to POSITIVE for every input — confirm intended.
true_class = 1


def generate_visual(text_batch, target_class):
    """Generate LRP token-relevance scores and an HTML visualization.

    Args:
        text_batch: Text input accepted by the module-level ``tokenizer``
            (a string, or whatever the tokenizer call supports).
        target_class: Index into ``classifications`` (0 = NEGATIVE,
            1 = POSITIVE) selecting the class to explain.

    Returns:
        Tuple ``(token_importance, html)`` where ``token_importance`` maps
        each token string to its rounded relevance score and ``html`` is the
        rendered captum visualization markup.
    """
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    # Layer-wise Relevance Propagation scores, taken from encoder layer 11
    # for the requested class.
    expl = \
        explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11,
                                  index=target_class)[0]
    # Min-max normalize to [0, 1]. Guard the degenerate all-equal case, which
    # would otherwise divide by zero and fill the scores with NaNs.
    score_range = expl.max() - expl.min()
    if score_range > 0:
        expl = (expl - expl.min()) / score_range
    else:
        expl = torch.zeros_like(expl)
    # Plain classification forward pass — no gradients needed here, so skip
    # building the autograd graph.
    with torch.no_grad():
        output = torch.nn.functional.softmax(
            model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()
    class_name = classifications[target_class]
    if class_name == "NEGATIVE":
        # Flip the sign so negative-class relevance renders in the opposite
        # color in the captum visualization.
        expl *= (-1)
    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    # NOTE(review): repeated tokens overwrite earlier dict entries, so only
    # the last occurrence's score survives — confirm callers expect this.
    token_importance = {}
    for token, score in zip(tokens, expl):
        token_importance[token] = round(score.item(), 3)
    vis_data_records = [visualization.VisualizationDataRecord(
        expl,
        output[0][classification],
        classification,
        true_class,
        true_class,
        1,
        tokens,
        1)]
    html_page = visualization.visualize_text(vis_data_records)
    return token_importance, html_page.data