marianvd-01 commited on
Commit
76269ea
·
verified ·
1 Parent(s): 97e9387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -41
app.py CHANGED
@@ -1,45 +1,69 @@
1
  # app.py
2
-
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
5
- import torch
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
-
9
- # Load some default model
10
- MODEL_NAME = "bert-base-uncased"
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
- model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
13
-
14
- def visualize_attention(text):
15
- inputs = tokenizer(text, return_tensors="pt")
16
- outputs = model(**inputs)
17
-
18
- # Grab attentions from output
19
- attentions = outputs.attentions # List of (num_layers, batch, num_heads, seq_len, seq_len)
20
- tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
21
-
22
- fig, ax = plt.subplots(figsize=(8, 6))
23
- # Just visualize attention from last layer, first head
24
- attn_matrix = attentions[-1][0][0].detach().numpy()
25
-
26
- cax = ax.matshow(attn_matrix, cmap='viridis')
27
- fig.colorbar(cax)
28
-
29
- ax.set_xticks(range(len(tokens)))
30
- ax.set_yticks(range(len(tokens)))
31
- ax.set_xticklabels(tokens, rotation=90)
32
- ax.set_yticklabels(tokens)
33
- ax.set_title("Attention Map - Last Layer, Head 1")
34
-
35
- return fig
36
-
37
- iface = gr.Interface(
38
- fn=visualize_attention,
39
- inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
40
- outputs=gr.Plot(),
41
- title="🧠 Transformer Attention Visualizer",
42
- description="Visualizes the self-attention of the BERT model's last layer."
43
  )
44
 
45
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
 
2
  import gradio as gr
3
+ from model_utils import load_model_info, get_model_stats
4
+ from visualize import (
5
+ visualize_attention,
6
+ visualize_token_embeddings,
7
+ plot_tokenization,
8
+ compare_model_sizes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  )
10
 
11
+ MODEL_CHOICES = {
12
+ "BERT (base)": "bert-base-uncased",
13
+ "DistilBERT": "distilbert-base-uncased",
14
+ "RoBERTa": "roberta-base",
15
+ "GPT-2": "gpt2",
16
+ "Electra": "google/electra-base-discriminator",
17
+ "ALBERT": "albert-base-v2",
18
+ "XLNet": "xlnet-base-cased"
19
+ }
20
+
21
+ def run_visualizer(model_name, text, layer, head):
22
+ model_info = load_model_info(model_name)
23
+ attention_plot = visualize_attention(model_info, text, layer, head)
24
+ token_heatmap = visualize_token_embeddings(model_info, text)
25
+ token_plot = plot_tokenization(model_info, text)
26
+ model_stats = get_model_stats(model_info)
27
+
28
+ return attention_plot, token_heatmap, token_plot, model_stats
29
+
30
+ def run_comparison_chart():
31
+ return compare_model_sizes(MODEL_CHOICES.values())
32
+
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("""
35
+ # 🤖 Transformer Model Visualizer
36
+ Explore attention heads, token embeddings, and tokenizer behavior across popular transformer models.
37
+ """)
38
+
39
+ with gr.Row():
40
+ model_selector = gr.Dropdown(label="Choose Model", choices=list(MODEL_CHOICES.keys()), value="BERT (base)")
41
+ input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")
42
+
43
+ with gr.Row():
44
+ layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
45
+ head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")
46
+
47
+ run_btn = gr.Button("Run Analysis")
48
+
49
+ with gr.Row():
50
+ attention_output = gr.Plot(label="Self-Attention Visualization")
51
+ embedding_output = gr.Plot(label="Token Embedding Heatmap")
52
+
53
+ with gr.Row():
54
+ token_output = gr.Plot(label="Tokenization Overview")
55
+ model_output = gr.JSON(label="Model Details")
56
+
57
+ run_btn.click(
58
+ fn=run_visualizer,
59
+ inputs=[model_selector, input_text, layer_slider, head_slider],
60
+ outputs=[attention_output, embedding_output, token_output, model_output]
61
+ )
62
+
63
+ with gr.Accordion("📊 Compare Model Sizes", open=False):
64
+ compare_btn = gr.Button("Generate Comparison Chart")
65
+ comparison_output = gr.Plot()
66
+ compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)
67
+
68
+ demo.launch()
69
+