# app.py
"""Gradio front-end for exploring attention, embeddings and tokenization of transformer models."""

import gradio as gr

from model_utils import load_model_info, get_model_stats
from visualize import (
    visualize_attention,
    visualize_token_embeddings,
    plot_tokenization,
    compare_model_sizes,
)

# Dropdown display label -> model identifier handed to the model/visualization helpers.
MODEL_CHOICES = {
    "BERT (base)": "bert-base-uncased",
    "DistilBERT": "distilbert-base-uncased",
    "RoBERTa": "roberta-base",
    "GPT-2": "gpt2",
    "Electra": "google/electra-base-discriminator",
    "ALBERT": "albert-base-v2",
    "XLNet": "xlnet-base-cased",
}


def run_visualizer(model_name, text, layer, head):
    """Produce every per-model visualization for one piece of input text.

    Args:
        model_name: Dropdown label (a key of ``MODEL_CHOICES``). A value that is
            not a key is passed through unchanged, so a raw model id also works.
        text: Text to analyze; must contain at least one non-whitespace character.
        layer: Layer index from the slider (coerced to ``int``).
        head: Attention-head index from the slider (coerced to ``int``).

    Returns:
        ``(attention_plot, token_heatmap, token_plot, model_stats)``, in the same
        order as the ``outputs`` list wired to the "Run Analysis" button below.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, so the UI shows a
            readable message instead of a downstream traceback.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    # The dropdown yields the display label, not the model id. Translate it so
    # load_model_info receives the same kind of value that run_comparison_chart
    # passes to compare_model_sizes.
    # NOTE(review): assumes load_model_info expects the model id — confirm in model_utils.
    model_id = MODEL_CHOICES.get(model_name, model_name)
    model_info = load_model_info(model_id)

    # Slider values may arrive as floats; layer/head are used as indices.
    layer = int(layer)
    head = int(head)

    attention_plot = visualize_attention(model_info, text, layer, head)
    token_heatmap = visualize_token_embeddings(model_info, text)
    token_plot = plot_tokenization(model_info, text)
    model_stats = get_model_stats(model_info)
    return attention_plot, token_heatmap, token_plot, model_stats


def run_comparison_chart():
    """Build the size-comparison chart for every model listed in ``MODEL_CHOICES``."""
    # Pass a real list rather than a dict_values view, which cannot be indexed or sliced.
    return compare_model_sizes(list(MODEL_CHOICES.values()))


with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🤖 Transformer Model Visualizer
Explore attention heads, token embeddings, and tokenizer behavior across popular transformer models.
"""
    )

    with gr.Row():
        model_selector = gr.Dropdown(
            label="Choose Model",
            choices=list(MODEL_CHOICES.keys()),
            value="BERT (base)",
        )
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")

    with gr.Row():
        # maximum=11 assumes 12 layers / 12 heads for every model.
        # NOTE(review): DistilBERT is believed to have only 6 layers; a higher
        # layer index may fail for it — consider updating the slider range on
        # model change.
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")

    run_btn = gr.Button("Run Analysis")

    with gr.Row():
        attention_output = gr.Plot(label="Self-Attention Visualization")
        embedding_output = gr.Plot(label="Token Embedding Heatmap")

    with gr.Row():
        token_output = gr.Plot(label="Tokenization Overview")
        model_output = gr.JSON(label="Model Details")

    # Output order must match the tuple returned by run_visualizer.
    run_btn.click(
        fn=run_visualizer,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[attention_output, embedding_output, token_output, model_output],
    )

    with gr.Accordion("📊 Compare Model Sizes", open=False):
        compare_btn = gr.Button("Generate Comparison Chart")
        comparison_output = gr.Plot()
        compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)


if __name__ == "__main__":
    # Only start the server when run as a script, not when the module is imported.
    demo.launch()