# Page-header residue from the original export, preserved as comments:
# Spaces:
# Runtime error
import os

import spacy
import torch
from captum.attr import visualization
from IPython.display import HTML, Image, display
from transformers import AutoTokenizer, BertTokenizer

from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from sequenceoutput.modeling_output import SequenceClassifierOutput
# Load the fine-tuned model from the local checkpoint directory and move it to
# the GPU when one is available, falling back to CPU so the script also runs on
# CPU-only machines (the original hard-coded "cuda" and crashed without a GPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained("./BERT").to(device)
model.eval()  # inference mode: disables dropout
tokenizer = AutoTokenizer.from_pretrained("./BERT")
# initialize the explanations generator
explanations = Generator(model)
# label names in model-output order: index 0 -> NEGATIVE, index 1 -> POSITIVE
classifications = ["NEGATIVE", "POSITIVE"]
# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
# Put the inputs on whatever device the model lives on (works for CPU and
# GPU alike, instead of hard-coding "cuda").
device = next(model.parameters()).device
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# true class is positive - 1
true_class = 1
# generate an explanation of the input with respect to target_class
target_class = 0
expl = explanations.generate_LRP(
    input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class
)[0]
# Min-max normalize scores to [0, 1]. Guard the degenerate case where every
# score is identical: max == min would divide by zero and yield NaNs.
score_span = expl.max() - expl.min()
if score_span > 0:
    expl = (expl - expl.min()) / score_span
else:
    expl = torch.zeros_like(expl)
# get the model classification (softmax over the logits)
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# name of the class being explained (NOT necessarily the predicted class)
class_name = classifications[target_class]
# if the explained class is negative, flip the sign of the scores so the
# visualization renders them as negative contributions
if class_name == "NEGATIVE":
    expl *= (-1)
# Map each token to its attribution score.
# NOTE(review): repeated tokens overwrite one another in this dict — if
# per-occurrence scores matter, a list of (token, score) pairs is needed.
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
token_importance = {token: score.item() for token, score in zip(tokens, expl)}
# Assemble a single Captum visualization record. Positional argument order is:
# (word_attributions, pred_prob, pred_class, true_class, attr_class,
#  attr_score, raw_input_ids, convergence_score).
record = visualization.VisualizationDataRecord(
    expl,
    output[0][classification],
    classification,
    true_class,
    true_class,
    1,
    tokens,
    1,
)
vis_data_records = [record]
html1 = visualization.visualize_text(vis_data_records)
# print(token_importance, html1)
# with open('bert-xai.html', 'w+') as f:
#     f.write(str(html1))
# return token_importance, html1