import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(model_name, sentence):
    # Loaded fresh on every call; fine for a demo, but see the optional
    # cached loader sketch further below for repeated use.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Tuple with one tensor per layer, each of shape
    # (batch, num_heads, seq_len, seq_len).
    attentions = outputs.attentions
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Plot the last layer's first head as a token-by-token heatmap.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attentions[-1][0, 0].numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="viridis",
                ax=ax)
    ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1")
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    fig.tight_layout()
    return fig
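
# Optional sketch (hypothetical helper, not part of the original app): the
# attentions tuple is indexed as attentions[layer][batch, head, query, key],
# so any layer/head pair can be pulled out the same way the last layer's
# first head is above.
def attention_matrix(attentions, layer=-1, head=0):
    # e.g. attention_matrix(outputs.attentions, layer=-1, head=0) gives a
    # (seq_len, seq_len) NumPy array for the last layer's first head.
    return attentions[layer][0, head].detach().numpy()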
model_list = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
]
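
# Optional sketch (assumption: a small, fixed model list like the one above):
# cache (tokenizer, model) pairs with functools.lru_cache so re-selecting a
# model in the dropdown does not reload its weights on every request.
from functools import lru_cache

@lru_cache(maxsize=4)
def load_model(model_name):
    # Keyed on the model name string; returns the same pair on repeat calls.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    return tokenizer, model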
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Dropdown(choices=model_list, label="Choose Transformer Model"),
        gr.Textbox(label="Enter Input Sentence"),
    ],
    outputs=gr.Plot(label="Attention Map"),
    title="Transformer Attention Visualizer",
    description="Visualize attention heads of transformer models. Select a model and input text to see attention heatmaps.",
)
iface.launch()