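"""Streamlit app that visualizes transformer model architectures.

Provides a sidebar model picker, summary metrics, a parameter-count
comparison chart, and a simplified attention-pattern heatmap.
"""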
import streamlit as st
import matplotlib.pyplot as plt
import torch
from transformers import AutoConfig

# Page configuration
st.set_page_config(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS styling
st.markdown("""
<style>
    .reportview-container {
        background: linear-gradient(45deg, #1a1a1a, #4a4a4a);
    }
    .sidebar .sidebar-content {
        background: #2c2c2c !important;
    }
    h1, h2, h3, h4, h5, h6 {
        color: #00ff00 !important;
    }
    .stMetric {
        background-color: #333333;
        border-radius: 10px;
        padding: 15px;
    }
</style>
""", unsafe_allow_html=True)

# Per-model metadata: architecture type, layer/head counts, and approximate parameter count in millions
MODELS = {
    "BERT": {"model_name": "bert-base-uncased", "type": "Encoder", "layers": 12, "heads": 12, "params": 109.48},
    "GPT-2": {"model_name": "gpt2", "type": "Decoder", "layers": 12, "heads": 12, "params": 117},
    "T5-Small": {"model_name": "t5-small", "type": "Seq2Seq", "layers": 6, "heads": 8, "params": 60},
    "RoBERTa": {"model_name": "roberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 125},
    "DistilBERT": {"model_name": "distilbert-base-uncased", "type": "Encoder", "layers": 6, "heads": 12, "params": 66},
    "ALBERT": {"model_name": "albert-base-v2", "type": "Encoder", "layers": 12, "heads": 12, "params": 11.8},
    "ELECTRA": {"model_name": "google/electra-small-discriminator", "type": "Encoder", "layers": 12, "heads": 12, "params": 13.5},
    "XLNet": {"model_name": "xlnet-base-cased", "type": "AutoRegressive", "layers": 12, "heads": 12, "params": 110},
    "BART": {"model_name": "facebook/bart-base", "type": "Seq2Seq", "layers": 6, "heads": 16, "params": 139},
    "DeBERTa": {"model_name": "microsoft/deberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 139}
}

@st.cache_resource
def get_model_config(model_name):
    """Fetch the Hugging Face config for the selected model, cached across reruns."""
    return AutoConfig.from_pretrained(MODELS[model_name]["model_name"])

def plot_model_comparison(selected_model):
    """Bar chart of parameter counts across all models, highlighting the selected one."""
    model_names = list(MODELS.keys())
    params = [m["params"] for m in MODELS.values()]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(model_names, params)

    # Highlight the selected model's bar
    bars[model_names.index(selected_model)].set_color('#00ff00')
    
    ax.set_ylabel('Parameters (Millions)', color='white')
    ax.set_title('Model Size Comparison', color='white')
    ax.tick_params(axis='x', rotation=45, colors='white')
    ax.tick_params(axis='y', colors='white')
    ax.set_facecolor('#2c2c2c')
    fig.patch.set_facecolor('#2c2c2c')
    
    st.pyplot(fig)

def visualize_attention_patterns():
    """Heatmap of a random row-normalized matrix, standing in for one head's attention weights."""
    fig, ax = plt.subplots(figsize=(8, 6))
    # Softmax over each row so the matrix behaves like attention weights (rows sum to 1)
    data = torch.softmax(torch.randn(5, 5), dim=-1)
    ax.imshow(data.numpy(), cmap='viridis')
    ax.set_title('Attention Patterns Example', color='white')
    ax.set_facecolor('#2c2c2c')
    fig.patch.set_facecolor('#2c2c2c')
    st.pyplot(fig)

def main():
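    """Entry point: sidebar model picker, summary metrics, and visualization tabs."""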
    st.title("🧠 Transformer Model Visualizer")
    
    # Model selection
    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    
    # Model details
    model_info = MODELS[selected_model]
    config = get_model_config(selected_model)
    
    # Display metrics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Model Type", model_info["type"])
    with col2:
        st.metric("Layers", model_info["layers"])
    with col3:
        st.metric("Attention Heads", model_info["heads"])
    with col4:
        st.metric("Parameters", f"{model_info['params']}M")
    
    # Visualization tabs
    tab1, tab2, tab3 = st.tabs(["Model Structure", "Comparison", "Model Specific"])
    
    with tab1:
        st.subheader("Architecture Diagram")
        st.image(
            "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8a/Transformer_model.svg/1200px-Transformer_model.svg.png",
            use_container_width=True,
        )
    
    with tab2:
        st.subheader("Model Size Comparison")
        plot_model_comparison(selected_model)
    
    with tab3:
        st.subheader("Model-specific Visualizations")
        visualize_attention_patterns()
        if selected_model == "BERT":
            st.write("BERT-specific visualization example")
        elif selected_model == "GPT-2":
            st.write("GPT-2 attention mask visualization")

if __name__ == "__main__":
    main()