|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
def visualize_attention(model_name, sentence): |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModel.from_pretrained(model_name, output_attentions=True) |
|
|
|
inputs = tokenizer(sentence, return_tensors='pt') |
|
outputs = model(**inputs) |
|
attentions = outputs.attentions |
|
|
|
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) |
|
|
|
fig, ax = plt.subplots(figsize=(10, 8)) |
|
sns.heatmap(attentions[-1][0][0].detach().numpy(), |
|
xticklabels=tokens, |
|
yticklabels=tokens, |
|
cmap="viridis", |
|
ax=ax) |
|
ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1") |
|
plt.xticks(rotation=90) |
|
plt.yticks(rotation=0) |
|
|
|
return fig |
|
|
|
model_list = [ |
|
"bert-base-uncased", |
|
"roberta-base", |
|
"distilbert-base-uncased" |
|
] |
|
|
|
iface = gr.Interface( |
|
fn=visualize_attention, |
|
inputs=[ |
|
gr.Dropdown(choices=model_list, label="Choose Transformer Model"), |
|
gr.Textbox(label="Enter Input Sentence") |
|
], |
|
outputs=gr.Plot(label="Attention Map"), |
|
title="Transformer Attention Visualizer", |
|
description="Visualize attention heads of transformer models. Select a model and input text to see attention heatmaps." |
|
) |
|
|
|
iface.launch() |