import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(model_name, sentence):
    # Load the tokenizer and model; output_attentions=True makes the model
    # return attention weights for every layer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)

    # Tokenize the sentence and run a forward pass without tracking gradients.
    inputs = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs.attentions is a tuple with one tensor per layer,
    # each of shape (batch, num_heads, seq_len, seq_len).
    attentions = outputs.attentions
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Plot head 1 of the last layer as a token-by-token heatmap.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attentions[-1][0][0].detach().numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="viridis",
                ax=ax)
    ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1")
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    fig.tight_layout()
    return fig

model_list = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased"
]

iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Dropdown(choices=model_list, label="Choose Transformer Model"),
        gr.Textbox(label="Enter Input Sentence")
    ],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    description="Visualize attention heads of transformer models. Select a model and input text to see attention heatmaps."
)

iface.launch()
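
# Usage note (assumption, not part of the original file): save this script as,
# say, app.py and run `python app.py`; Gradio serves the interface locally
# (by default at http://127.0.0.1:7860). Passing share=True to iface.launch()
# creates a temporary public link if you want to share the demo.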