# Provenance (web-page header captured together with this file):
# uploaded by WwYc, commit message "Upload 69 files", commit 1f1df39 (verified).
import os
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from transformers import AutoTokenizer
from captum.attr import visualization
import spacy
import torch
from IPython.display import Image, HTML, display
from sequenceoutput.modeling_output import SequenceClassifierOutput
# Prefer the GPU, but fall back to the CPU so the script also runs on hosts
# without CUDA (a hard-coded "cuda" target raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the sequence classifier and its tokenizer from the local ./BERT directory.
model = BertForSequenceClassification.from_pretrained("./BERT").to(device)
model.eval()  # switch the module to evaluation mode
tokenizer = AutoTokenizer.from_pretrained("./BERT")

# initialize the explanations generator
explanations = Generator(model)
# Class names, indexed by class id (0 -> NEGATIVE, 1 -> POSITIVE).
classifications = ["NEGATIVE", "POSITIVE"]

# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
# The inputs are moved to the same device the model was placed on.
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# true class is positive - 1
true_class = 1

# generate an explanation for the input, with respect to `target_class`
target_class = 0
expl = explanations.generate_LRP(
    input_ids=input_ids,
    attention_mask=attention_mask,
    start_layer=11,
    index=target_class,
)[0]

# Min-max normalize the scores. When every score is identical the range is
# zero and a plain division would yield NaNs; substitute 1 for the denominator
# in that case so the result is all zeros. Non-degenerate inputs are unchanged.
expl_min = expl.min()
expl_range = expl.max() - expl_min
safe_range = torch.where(expl_range > 0, expl_range, torch.ones_like(expl_range))
expl = (expl - expl_min) / safe_range

# get the model classification: softmax over the first element of the model
# output, then the arg-max class id for the single encoded sentence
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
output = torch.nn.functional.softmax(logits, dim=-1)
classification = output.argmax(dim=-1).item()

# get the name of the class the explanation was generated for
class_name = classifications[target_class]
# if the explained (target) class is negative, higher explanation scores are
# more negative - flip the sign for visualization
if class_name == "NEGATIVE":
    expl = -expl
# Map each token string to its explanation score as a plain Python float.
# NOTE: the mapping is keyed by token text, so a token that occurs more than
# once in the input keeps only the score of its last occurrence.
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
token_importance = {
    token: expl[position].item() for position, token in enumerate(tokens)
}
# Build a single captum visualization record for the encoded sentence.
# Arguments are passed positionally; their roles below are presumed from
# captum's VisualizationDataRecord signature - confirm against the installed
# captum version.
_record = visualization.VisualizationDataRecord(
    expl,                       # per-token attribution scores
    output[0][classification],  # softmax score of the predicted class
    classification,             # predicted class id
    true_class,                 # true class
    true_class,                 # attribution class (same value as the true class here)
    1,                          # attribution score (constant placeholder)
    tokens,                     # token strings shown in the rendering
    1,                          # convergence score (constant placeholder)
)
vis_data_records = [_record]

# Render the record(s) with captum's text visualizer and keep the result.
html1 = visualization.visualize_text(vis_data_records)

# Disabled debugging / export code kept from the original:
# print(token_importance, html1)
# with open('bert-xai.html', 'w+') as f:
# f.write(str(html1))
# return token_importance, html1