Transformer / utils.py
rahideer's picture
Create utils.py
b36e408 verified
raw
history blame
943 Bytes
import plotly.graph_objects as go
import numpy as np
def list_supported_models(task):
    """Return the model checkpoint names offered for *task*.

    Args:
        task: Task label; the recognised values are "Text Classification",
            "Text Generation" and "Question Answering".

    Returns:
        A list of checkpoint identifier strings for the task, or an empty
        list when the label is not one of the recognised values.
    """
    # Built per call so every caller receives its own list object.
    catalog = (
        (
            "Text Classification",
            ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"],
        ),
        (
            "Text Generation",
            ["gpt2", "distilgpt2"],
        ),
        (
            "Question Answering",
            ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"],
        ),
    )
    for task_label, checkpoints in catalog:
        if task == task_label:
            return checkpoints
    return []
def visualize_attention(attentions, tokenizer, inputs):
    """Build a heatmap of the last layer's head-averaged attention.

    Args:
        attentions: Sequence of per-layer attention tensors. Only the last
            entry is used; its first item (index 0) is averaged over its
            leading axis. Presumably each entry is shaped
            (batch, num_heads, seq_len, seq_len), as produced by a model
            called with ``output_attentions=True`` -- confirm with caller.
        tokenizer: Object providing ``convert_ids_to_tokens``.
        inputs: Mapping containing ``"input_ids"``; only the first row is
            converted to tokens for the axis labels.

    Returns:
        A ``plotly.graph_objects.Figure`` holding one heatmap trace.

    Raises:
        ValueError: If ``attentions`` is ``None`` or empty.
    """
    if attentions is None or len(attentions) == 0:
        raise ValueError(
            "No attention weights to visualize; run the model with "
            "output_attentions=True."
        )

    tokens = list(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))

    last_layer_attention = attentions[-1][0]  # [num_heads, seq_len, seq_len]
    # Average over heads on the source device, then move to host memory.
    # .cpu() is required for tensors living on an accelerator, and .float()
    # for dtypes NumPy cannot represent (e.g. bfloat16).
    avg_attention = (
        last_layer_attention.detach().float().mean(dim=0).cpu().numpy()
    )

    # Plot against integer positions rather than the token strings: a
    # categorical axis merges repeated labels, so a sentence containing the
    # same token twice would lose rows/columns. Tokens are shown as tick text.
    positions = list(range(len(tokens)))
    hover_text = [
        [f"{row_token} → {col_token}" for col_token in tokens]
        for row_token in tokens
    ]

    fig = go.Figure(data=go.Heatmap(
        z=avg_attention,
        x=positions,
        y=positions,
        text=hover_text,
        hovertemplate="%{text}<br>attention=%{z:.4f}<extra></extra>",
        colorscale='Viridis',
    ))
    fig.update_layout(
        title="Average Attention - Last Layer",
        xaxis={"tickmode": "array", "tickvals": positions, "ticktext": tokens},
        yaxis={"tickmode": "array", "tickvals": positions, "ticktext": tokens},
    )
    return fig