"""Explain a BERT sequence classifier's output for one sentence using LRP.

The script loads a fine-tuned two-class (NEGATIVE/POSITIVE) model from
``./BERT``, computes per-token relevance scores for a chosen target class with
``Generator.generate_LRP``, min-max normalizes them, and renders the result
with Captum's text visualization.

Module-level results:
    token_scores:     ordered ``(token, score)`` pairs, one per input token.
    token_importance: ``token -> score`` mapping (last occurrence wins).
    html1:            value returned by ``visualization.visualize_text``.
"""
import os

import spacy
import torch
from captum.attr import visualization
from IPython.display import HTML, Image, display
from transformers import AutoTokenizer, BertTokenizer

from BERT_explainability.modules.BERT.BertForSequenceClassification import (
    BertForSequenceClassification,
)
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from sequenceoutput.modeling_output import SequenceClassifierOutput

# Directory holding both the fine-tuned model weights and the tokenizer files.
MODEL_DIR = "./BERT"

# Use the GPU when one is present instead of failing outright at model load.
# NOTE(review): the explanation generator may itself require CUDA internally;
# confirm before relying on the CPU fallback.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = BertForSequenceClassification.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# initialize the explanations generator
explanations = Generator(model)

# Class names indexed by class id.
classifications = ["NEGATIVE", "POSITIVE"]

# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors="pt")
input_ids = encoding["input_ids"].to(DEVICE)
attention_mask = encoding["attention_mask"].to(DEVICE)

# true class is positive - 1
true_class = 1

# generate an explanation for the input, with respect to the NEGATIVE class
target_class = 0
# [0] selects the first entry of the generator's result
# (presumably the scores for the single sentence in the batch -- TODO confirm).
expl = explanations.generate_LRP(
    input_ids=input_ids,
    attention_mask=attention_mask,
    start_layer=11,
    index=target_class,
)[0]

# normalize scores to the [0, 1] range (min-max scaling)
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
# NOTE: intentionally not wrapped in torch.no_grad(); the custom model may
# register gradient hooks during forward, which needs autograd enabled.
logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
output = torch.nn.functional.softmax(logits, dim=-1)
classification = output.argmax(dim=-1).item()

# get class name
class_name = classifications[target_class]

# if the classification is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl *= -1

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())

# Ordered (token, score) pairs. Unlike a dict keyed by token text, this keeps
# every occurrence of a repeated token (the example sentence contains "I"
# twice), so no score is silently dropped.
token_scores = [(token, expl[i].item()) for i, token in enumerate(tokens)]

# Token -> score mapping kept for backward compatibility. When a token occurs
# more than once only its last score survives; use token_scores for the
# complete, position-preserving view.
token_importance = dict(token_scores)

# Positional arguments follow Captum's VisualizationDataRecord signature:
# word attributions, predicted probability, predicted class, true class,
# attribution class, attribution score, raw input tokens, convergence score.
vis_data_records = [
    visualization.VisualizationDataRecord(
        expl,
        output[0][classification],
        classification,
        true_class,
        true_class,
        1,
        tokens,
        1,
    )
]
html1 = visualization.visualize_text(vis_data_records)

# Disabled leftovers, apparently from a function-wrapped variant of this
# script; kept for reference.
# print(token_importance, html1)
# with open('bert-xai.html', 'w+') as f:
#     f.write(str(html1))
# return token_importance, html1