# explain-BERT-twoclass / generic.py
import sys
sys.path.append('BERT')  # make the local BERT_explainability package importable

import torch
from transformers import AutoTokenizer
from captum.attr import visualization
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
# load the fine-tuned two-class model and its tokenizer from the local checkpoint
model = BertForSequenceClassification.from_pretrained("./BERT/BERT_weight")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("./BERT/BERT_weight")

# initialize the explanations generator (relevance propagation through the model)
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]
true_class = 1  # ground-truth label used to annotate the visualization
def generate_visual(text_batch, target_class):
    """Return per-token LRP importances and an HTML visualization for target_class."""
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # relevance scores, propagated from the last encoder layer (layer 11 in BERT-base)
    expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask,
                                     start_layer=11, index=target_class)[0]
    # min-max normalize the relevances to [0, 1]
    expl = (expl - expl.min()) / (expl.max() - expl.min())

    # class probabilities for the same input; no gradients needed for this forward pass
    with torch.no_grad():
        output = torch.nn.functional.softmax(
            model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()

    # negate relevances for the NEGATIVE class so Captum renders them as negative evidence
    class_name = classifications[target_class]
    if class_name == "NEGATIVE":
        expl *= (-1)

    # map each token to its rounded relevance (a repeated token keeps its last score)
    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    token_importance = {token: round(score.item(), 3) for token, score in zip(tokens, expl)}

    # Captum record: (word_attributions, pred_prob, pred_class, true_class,
    #                 attr_class, attr_score, raw_input_ids, convergence_score)
    vis_data_records = [visualization.VisualizationDataRecord(
        expl,
        output[0][classification],
        classification,
        true_class,
        target_class,  # the class the attributions were computed for
        1,
        tokens,
        1)]
    html_page = visualization.visualize_text(vis_data_records)
    return token_importance, html_page.data
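
# A minimal usage sketch, assuming the checkpoint above is in place; the sample
# sentence and the "explanation.html" output path are illustrative choices, not
# part of the original script:
if __name__ == "__main__":
    sample = "This movie was absolutely wonderful, I loved every minute of it."
    importances, html = generate_visual(sample, target_class=1)
    print(importances)  # e.g. {'[CLS]': 0.0, 'wonderful': 0.87, ...} (values illustrative)
    with open("explanation.html", "w") as f:
        f.write(html)  # open in a browser to inspect the token highlighting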