# app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
# Load the default model and tokenizer
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
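# output_attentions=True makes each forward pass also return the per-layer
# self-attention weights, which is what the visualizer below plots.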
def visualize_attention(text):
    """Plot the last layer's first-head self-attention matrix for the input text."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Grab attentions from the output: a tuple with one tensor per layer,
    # each of shape (batch, num_heads, seq_len, seq_len)
    attentions = outputs.attentions
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    fig, ax = plt.subplots(figsize=(8, 6))
    # Visualize attention from the last layer, first head only
    attn_matrix = attentions[-1][0][0].detach().numpy()
    cax = ax.matshow(attn_matrix, cmap='viridis')
    fig.colorbar(cax)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention Map - Last Layer, Head 1")
    return fig
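# Wire the function into a simple Gradio UI: free-form text in, a matplotlib figure out.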
iface = gr.Interface(
    fn=visualize_attention,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Plot(),
    title="🧠 Transformer Attention Visualizer",
    description="Visualizes the self-attention of the BERT model's last layer."
)
iface.launch()
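# Running `python app.py` starts a local Gradio server (http://127.0.0.1:7860 by
# default); on Hugging Face Spaces the app is served automatically.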