import sys

# Make the bundled BERT explainability code importable.
sys.path.append('BERT')

import torch
from captum.attr import visualization
from transformers import AutoTokenizer

from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
# Load the fine-tuned sentiment model and its tokenizer.
model = BertForSequenceClassification.from_pretrained("./BERT/BERT_weight")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("./BERT/BERT_weight")

# Initialize the LRP explanations generator.
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]
true_class = 1
def generate_visual(text_batch, target_class):
    # Tokenize the input text.
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # Per-token LRP relevance scores for the requested class, propagated
    # from the last encoder layer (layer 11 for BERT-base).
    expl = explanations.generate_LRP(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     start_layer=11,
                                     index=target_class)[0]
    # Min-max normalize the scores to [0, 1].
    expl = (expl - expl.min()) / (expl.max() - expl.min())

    # Model prediction for the same input.
    output = torch.nn.functional.softmax(
        model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()

    # Negate the scores for the negative class so Captum renders them
    # in red rather than green.
    class_name = classifications[target_class]
    if class_name == "NEGATIVE":
        expl *= (-1)

    # Map each token to its rounded relevance score. Repeated tokens
    # overwrite earlier entries, so the last occurrence wins.
    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    token_importance = {}
    for token, score in zip(tokens, expl):
        token_importance[token] = round(score.item(), 3)

    # Build a Captum record: per-token attributions, predicted probability,
    # predicted/true/attribution class labels, and the raw tokens.
    vis_data_records = [visualization.VisualizationDataRecord(
        expl,                       # word_attributions
        output[0][classification],  # pred_prob
        classification,             # pred_class
        true_class,                 # true_class
        true_class,                 # attr_class
        1,                          # attr_score
        tokens,                     # raw_input_ids
        1)]                         # convergence_score
    html_page = visualization.visualize_text(vis_data_records)
    return token_importance, html_page.data
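

# Minimal usage sketch: the sample sentence and output filename are
# placeholders, and this assumes the ./BERT/BERT_weight checkpoint is
# available locally.
if __name__ == "__main__":
    sample_text = "This movie was absolutely wonderful."
    importance, html = generate_visual(sample_text, target_class=1)
    print(importance)
    # Save the Captum visualization so it can be opened in a browser.
    with open("explanation.html", "w", encoding="utf-8") as f:
        f.write(html)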