Spaces:
Running
Running
import plotly.graph_objects as go | |
import numpy as np | |
def list_supported_models(task):
    """Return the model checkpoint names offered for ``task``.

    Args:
        task: Task label, e.g. "Text Classification", "Text Generation" or
            "Question Answering".

    Returns:
        A newly built list of checkpoint names for the task, or an empty list
        when the task label is not recognised.
    """
    # Ordered (label, models) pairs; matched with ``==`` so any argument,
    # hashable or not, simply falls through to the empty result.
    catalog = (
        ("Text Classification",
         ("distilbert-base-uncased", "bert-base-uncased", "roberta-base")),
        ("Text Generation",
         ("gpt2", "distilgpt2")),
        ("Question Answering",
         ("deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad")),
    )
    # Build a new list on every call so callers may mutate the result freely.
    return next((list(models) for label, models in catalog if task == label), [])
def visualize_attention(attentions, tokenizer, inputs):
    """Plot the head-averaged attention of the last layer as a heatmap.

    Args:
        attentions: Per-layer attention tensors. Only ``attentions[-1][0]``
            (last layer, first item of the batch) is used. It is presumably
            shaped ``[num_heads, seq_len, seq_len]``; confirm against the
            model call.
        tokenizer: Object providing ``convert_ids_to_tokens``; used to
            label the axes.
        inputs: Mapping with an ``"input_ids"`` entry. Only the first
            sequence (``inputs["input_ids"][0]``) is labelled.

    Returns:
        A ``plotly.graph_objects.Figure`` heatmap of the attention weights
        averaged over heads. Columns are labelled with tokens on the x-axis,
        rows with tokens on the y-axis.
    """
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_layer_attention = attentions[-1][0]  # shape: [num_heads, seq_len, seq_len]
    # .cpu() is required before .numpy() when the model ran on a GPU, and
    # .float() makes half/bfloat16 tensors convertible to NumPy.
    avg_attention = last_layer_attention.mean(dim=0).detach().cpu().float().numpy()

    # Plot against integer positions and attach the token strings as tick
    # labels. Passing the strings directly as x/y makes Plotly treat them as
    # categorical coordinates, so repeated tokens would collapse into a single
    # row/column and misalign the matrix.
    positions = list(range(len(tokens)))
    # Per-cell hover labels so the token names stay visible on hover.
    hover_text = [[f"row: {row}<br>col: {col}" for col in tokens] for row in tokens]

    fig = go.Figure(data=go.Heatmap(
        z=avg_attention,
        x=positions,
        y=positions,
        text=hover_text,
        hovertemplate="%{text}<br>attention: %{z:.4f}<extra></extra>",
        colorscale='Viridis',
    ))
    fig.update_layout(title="Average Attention - Last Layer")
    fig.update_xaxes(tickmode="array", tickvals=positions, ticktext=tokens)
    # Reverse the y-axis so row 0 is drawn at the top (matrix orientation)
    # instead of Plotly's default bottom-up ordering.
    fig.update_yaxes(
        tickmode="array",
        tickvals=positions,
        ticktext=tokens,
        autorange="reversed",
    )
    return fig