WwYc committed on
Commit 090a94e · verified · 1 Parent(s): 1f1df39

Create generic.py

Files changed (1)
  1. generic.py +58 -0
generic.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import sys
+ import spacy
+ import torch
+ from transformers import BertTokenizer, AutoTokenizer
+ from captum.attr import visualization
+ from IPython.display import Image, HTML, display
+ from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
+ from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
+
+ sys.path.append('BERT')
+
+ from sequenceoutput.modeling_output import SequenceClassifierOutput
+
+ # Load the fine-tuned model and tokenizer from the local checkpoint
+ model = BertForSequenceClassification.from_pretrained("./BERT_weight")
+ model.eval()
+ tokenizer = AutoTokenizer.from_pretrained("./BERT_weight")
+ # initialize the explanations generator
+ explanations = Generator(model)
+
+ classifications = ["NEGATIVE", "POSITIVE"]
+ true_class = 1
+
+
+ def generate_visual(text_batch, target_class):
+     encoding = tokenizer(text_batch, return_tensors='pt')
+     input_ids = encoding['input_ids']
+     attention_mask = encoding['attention_mask']
+     # Per-token relevance scores for the target class, via LRP starting at layer 11
+     expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask,
+                                      start_layer=11, index=target_class)[0]
+     # Normalize the relevance scores to [0, 1]
+     expl = (expl - expl.min()) / (expl.max() - expl.min())
+     # Model prediction (softmax over the logits)
+     output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
+     classification = output.argmax(dim=-1).item()
+     class_name = classifications[target_class]
+     # Flip the sign of the scores when explaining the NEGATIVE class
+     if class_name == "NEGATIVE":
+         expl *= (-1)
+     # Map each token to its relevance score
+     token_importance = {}
+     tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
+     for i in range(len(tokens)):
+         token_importance[tokens[i]] = expl[i].item()
+     # Build a captum visualization record and render it as HTML
+     vis_data_records = [visualization.VisualizationDataRecord(
+         expl,
+         output[0][classification],
+         classification,
+         true_class,
+         true_class,
+         1,
+         tokens,
+         1)]
+
+     html1 = visualization.visualize_text(vis_data_records)
+     return token_importance, html1
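For reference, a minimal usage sketch of generate_visual (not part of the commit; the input sentence and target class are illustrative, and it assumes the ./BERT_weight checkpoint and the BERT_explainability package are available locally):

# Hypothetical call: explain the POSITIVE class (index 1) for one sentence
token_importance, html = generate_visual("The movie was surprisingly good.", target_class=1)
print(token_importance)  # dict mapping each wordpiece token to its normalized relevance score
# html is whatever captum's visualize_text returns; in a Jupyter notebook it renders inline.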