# Spaces: Runtime error  (page-header text captured along with this source; not code)
import os | |
import sys | |
sys.path.append('BERT') | |
from transformers import BertTokenizer | |
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator | |
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification | |
from transformers import AutoTokenizer | |
from captum.attr import visualization | |
import torch | |
from sequenceoutput.modeling_output import SequenceClassifierOutput | |
# Shared objects used by generate_visual(). Weight and tokenizer files are read
# from a path relative to the current working directory.
tokenizer = AutoTokenizer.from_pretrained("./BERT/BERT_weight")
model = BertForSequenceClassification.from_pretrained("./BERT/BERT_weight")
model.eval()  # switch to evaluation mode (PyTorch: affects layers such as dropout)
explanations = Generator(model)  # LRP explanation generator wrapping the model

# Human-readable label for each class index (0 -> NEGATIVE, 1 -> POSITIVE).
classifications = ["NEGATIVE", "POSITIVE"]
# Fixed label handed to the visualisation record built in generate_visual().
true_class = 1
def _min_max_normalize(scores):
    """Rescale *scores* linearly into [0, 1].

    A constant tensor (max == min) has no spread to rescale; it is mapped to
    all zeros instead of the NaNs a bare ``(x - min) / (max - min)`` yields.
    """
    low = scores.min()
    spread = scores.max() - low
    if spread > 0:
        return (scores - low) / spread
    return torch.zeros_like(scores)


def generate_visual(text_batch, target_class):
    """Explain the model's decision for ``target_class`` on ``text_batch`` with LRP.

    Args:
        text_batch: Text handed straight to ``tokenizer``. Only the first
            sequence's relevance is used (``generate_LRP(...)[0]``), tokens are
            taken from the flattened ``input_ids`` and the prediction uses
            ``.item()``, so in practice this must be a single string.
        target_class: Index into ``classifications`` of the class to explain.

    Returns:
        ``(token_importance, html)`` where ``token_importance`` maps each token
        string to its relevance rounded to 3 decimals (in [0, 1], or [-1, 0]
        when the explained class is "NEGATIVE"), and ``html`` is the ``.data``
        of the object returned by captum's ``visualization.visualize_text``.
    """
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # Per-token relevance for target_class; start_layer=11 is kept from the
    # original configuration (presumably the last layer of a 12-layer BERT —
    # TODO confirm against the loaded checkpoint).
    expl = explanations.generate_LRP(
        input_ids=input_ids,
        attention_mask=attention_mask,
        start_layer=11,
        index=target_class,
    )[0]
    expl = _min_max_normalize(expl)

    # Separate forward pass to obtain the predicted class and its probability.
    output = torch.nn.functional.softmax(
        model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1
    )
    classification = output.argmax(dim=-1).item()

    # Flip the sign when explaining NEGATIVE so those scores land in [-1, 0].
    if classifications[target_class] == "NEGATIVE":
        expl = -expl

    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    # Keyed by token text: a token occurring more than once keeps only the
    # score of its last occurrence.
    token_importance = {
        token: round(score, 3) for token, score in zip(tokens, expl.tolist())
    }

    record = visualization.VisualizationDataRecord(
        expl,                       # word_attributions
        output[0][classification],  # pred_prob
        classification,             # pred_class
        true_class,                 # true_class (module-level constant)
        target_class,               # attr_class: the class actually explained
        1,                          # attr_score
        tokens,                     # raw_input_ids
        1,                          # convergence_score
    )
    html_page = visualization.visualize_text([record])
    return token_importance, html_page.data