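"""Gradio demo for exploring Transformer models.

Shows model metadata, a last-layer attention heatmap, and a 2-D PCA
projection of token embeddings for a user-supplied sentence.
"""
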
import gradio as gr
from transformers import AutoTokenizer, AutoModel, GPT2Model
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

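# Static reference metadata for the supported checkpoints.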
MODEL_INFO = {
    "bert-base-uncased": {
        "Model Type": "BERT",
        "Layers": 12,
        "Attention Heads": 12,
        "Parameters": "109.48M"
    },
    "roberta-base": {
        "Model Type": "RoBERTa",
        "Layers": 12,
        "Attention Heads": 12,
        "Parameters": "125M"
    },
    "distilbert-base-uncased": {
        "Model Type": "DistilBERT",
        "Layers": 6,
        "Attention Heads": 12,
        "Parameters": "66M"
    },
    "gpt2": {
        "Model Type": "GPT-2",
        "Layers": 12,
        "Attention Heads": 12,
        "Parameters": "124M"
    }
}

def visualize_transformer(model_name, sentence):
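    """Run `sentence` through `model_name` and return a text summary,
    a last-layer attention heatmap, and a PCA scatter of token embeddings."""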
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if "gpt2" in model_name:
        model = GPT2Model.from_pretrained(model_name, output_attentions=True)
        # GPT-2 ships without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
        inputs = tokenizer(sentence, return_tensors='pt', padding=True)
    else:
        model = AutoModel.from_pretrained(model_name, output_attentions=True)
        inputs = tokenizer(sentence, return_tensors='pt')

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions  # one (batch, heads, seq, seq) tensor per layer
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Heatmap of head 1 in the final layer: attentions[-1][0][0] selects
    # the last layer, first batch item, first head.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attentions[-1][0][0].numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="viridis",
                ax=ax)
    ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1")
    ax.tick_params(axis='x', rotation=90)
    ax.tick_params(axis='y', rotation=0)
    fig.tight_layout()

    # Numbered, quoted token list for the text summary panel.
    token_output = [f"{i + 1}: \"{tok}\"" for i, tok in enumerate(tokens)]
    token_output_str = "[\n" + "\n".join(token_output) + "\n]"

    # Project the per-token hidden states down to 2-D with PCA for plotting.
    last_hidden_state = outputs.last_hidden_state[0].numpy()
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(last_hidden_state)
    fig2, ax2 = plt.subplots()
    ax2.scatter(reduced[:, 0], reduced[:, 1])
    for i, token in enumerate(tokens):
        ax2.annotate(token, (reduced[i, 0], reduced[i, 1]))
    ax2.set_title("Token Embedding (PCA Projection)")

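    # Assemble the plain-text summary shown in the first output box.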
    model_info = MODEL_INFO.get(model_name, {})
    details = f"""
πŸ›  Model Details
Model Type: {model_info.get("Model Type", "Unknown")}
Number of Layers: {model_info.get("Layers", "?")}
Number of Attention Heads: {model_info.get("Attention Heads", "?")}
Total Parameters: {model_info.get("Parameters", "?")}

πŸ“Š Tokenization Visualization
Enter Text:
{sentence}
Tokenized Output:
{token_output_str}

πŸ“ˆ Model Size Comparison
- BERT: 109M
- RoBERTa: 125M
- DistilBERT: 66M
- GPT-2: 124M
"""

    return details, fig, fig2

model_list = list(MODEL_INFO.keys())

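# Gradio UI: model dropdown and sentence box in; summary text and two plots out.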
iface = gr.Interface(
    fn=visualize_transformer,
    inputs=[
        gr.Dropdown(choices=model_list, label="Choose Transformer Model"),
        gr.Textbox(label="Enter Input Sentence")
    ],
    outputs=[
        gr.Textbox(label="🧠 Model + Token Info", lines=25),
        gr.Plot(label="🧩 Attention Map"),
        gr.Plot(label="🧬 Token Embedding (PCA Projection)")
    ],
    title="Transformer Visualization App",
    description="Visualize Transformer models including token embeddings, attention maps, and model information."
)

if __name__ == "__main__":
    iface.launch()