Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import plotly.express as px
|
8 |
+
from sklearn.decomposition import PCA
|
9 |
+
from transformers import AutoModel, AutoTokenizer, pipeline, AutoModelForCausalLM
|
10 |
+
|
11 |
+
# App Title
|
12 |
+
st.title("π Transformer Model Explorer")
|
13 |
+
st.markdown("""
|
14 |
+
Explore different transformer models, their architectures, tokenization, and attention mechanisms.
|
15 |
+
""")
|
16 |
+
|
17 |
+
# Model Selection
|
18 |
+
model_name = st.selectbox(
|
19 |
+
"Choose a Transformer Model:",
|
20 |
+
["bigscience/bloom", "openai/whisper-base", "facebook/wav2vec2-base-960h"]
|
21 |
+
)
|
22 |
+
|
23 |
+
# Load Tokenizer & Model
|
24 |
+
st.write(f"Loading model: `{model_name}`...")
|
25 |
+
if "bloom" in model_name:
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
27 |
+
model = AutoModel.from_pretrained(model_name)
|
28 |
+
elif "whisper" in model_name:
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
30 |
+
model = AutoModel.from_pretrained(model_name)
|
31 |
+
elif "wav2vec2" in model_name:
|
32 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
33 |
+
model = AutoModel.from_pretrained(model_name)
|
34 |
+
|
35 |
+
# Display Model Details
|
36 |
+
st.subheader("π Model Details")
|
37 |
+
st.write(f"Model Type: `{model.config.model_type}`")
|
38 |
+
st.write(f"Number of Layers: `{model.config.num_hidden_layers}`")
|
39 |
+
st.write(f"Number of Attention Heads: `{model.config.num_attention_heads if hasattr(model.config, 'num_attention_heads') else 'N/A'}`")
|
40 |
+
st.write(f"Total Parameters: `{sum(p.numel() for p in model.parameters())/1e6:.2f}M`")
|
41 |
+
|
42 |
+
# Model Size Comparison
|
43 |
+
st.subheader("π Model Size Comparison")
|
44 |
+
model_sizes = {
|
45 |
+
"bigscience/bloom": 176,
|
46 |
+
"openai/whisper-base": 74,
|
47 |
+
"facebook/wav2vec2-base-960h": 317
|
48 |
+
}
|
49 |
+
df_size = pd.DataFrame(model_sizes.items(), columns=["Model", "Size (Million Parameters)"])
|
50 |
+
fig = px.bar(df_size, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
|
51 |
+
st.plotly_chart(fig)
|
52 |
+
|
53 |
+
# Tokenization Section
|
54 |
+
st.subheader("π Tokenization Visualization")
|
55 |
+
input_text = st.text_input("Enter Text:", "Hello, how are you?")
|
56 |
+
|
57 |
+
if "whisper" in model_name:
|
58 |
+
st.write("Note: Whisper is an audio model and doesn't use text tokenization")
|
59 |
+
st.write("Instead, it processes raw audio waveforms")
|
60 |
+
else:
|
61 |
+
tokens = tokenizer.tokenize(input_text)
|
62 |
+
st.write("Tokenized Output:", tokens)
|
63 |
+
|
64 |
+
# Token Embeddings Visualization (Fixed PCA Projection)
|
65 |
+
st.subheader("π§© Token Embeddings Visualization")
|
66 |
+
with torch.no_grad():
|
67 |
+
if "whisper" in model_name:
|
68 |
+
st.write("Note: Whisper uses a different embedding structure for audio features")
|
69 |
+
st.write("Cannot directly visualize token embeddings as with text models")
|
70 |
+
else:
|
71 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
72 |
+
outputs = model(**inputs)
|
73 |
+
if hasattr(outputs, "last_hidden_state"):
|
74 |
+
embeddings = outputs.last_hidden_state.squeeze(0).numpy()
|
75 |
+
# Ensure the number of tokens and embeddings match
|
76 |
+
n_tokens = min(len(tokens), embeddings.shape[0])
|
77 |
+
embeddings = embeddings[:n_tokens] # Trim embeddings to match token count
|
78 |
+
tokens = tokens[:n_tokens] # Trim tokens to match embeddings count
|
79 |
+
pca = PCA(n_components=2)
|
80 |
+
reduced_embeddings = pca.fit_transform(embeddings)
|
81 |
+
df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
|
82 |
+
df_embeddings["Token"] = tokens
|
83 |
+
fig = px.scatter(df_embeddings, x="PCA1", y="PCA2", text="Token",
|
84 |
+
title="Token Embeddings (PCA Projection)")
|
85 |
+
st.plotly_chart(fig)
|
86 |
+
|
87 |
+
# Attention Visualization (for BERT & RoBERTa models)
|
88 |
+
if "bloom" in model_name:
|
89 |
+
st.subheader("π Attention Map")
|
90 |
+
with torch.no_grad():
|
91 |
+
outputs = model(**inputs, output_attentions=True)
|
92 |
+
attention = outputs.attentions[-1].squeeze().detach().numpy()
|
93 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
94 |
+
sns.heatmap(attention[0], cmap="viridis", xticklabels=tokens, yticklabels=tokens, ax=ax)
|
95 |
+
st.pyplot(fig)
|
96 |
+
|
97 |
+
# Text Generation Demo (for BLOOM)
|
98 |
+
if "bloom" in model_name:
|
99 |
+
st.subheader("βοΈ Text Generation & Token Probabilities")
|
100 |
+
generator = pipeline("text-generation", model=model_name, return_full_text=False)
|
101 |
+
generated_output = generator(input_text, max_length=50, return_tensors=True)
|
102 |
+
st.write("Generated Output:", generated_output[0]["generated_text"])
|
103 |
+
|
104 |
+
# Token Probability Visualization
|
105 |
+
model_gen = AutoModelForCausalLM.from_pretrained(model_name)
|
106 |
+
with torch.no_grad():
|
107 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
108 |
+
logits = model_gen(**inputs).logits[:, -1, :]
|
109 |
+
probs = torch.nn.functional.softmax(logits, dim=-1).squeeze().detach().numpy()
|
110 |
+
top_tokens = np.argsort(probs)[-10:][::-1] # Top 10 tokens
|
111 |
+
token_probs = {tokenizer.decode([idx]): probs[idx] for idx in top_tokens}
|
112 |
+
df_probs = pd.DataFrame(token_probs.items(), columns=["Token", "Probability"])
|
113 |
+
fig_prob = px.bar(df_probs, x="Token", y="Probability", title="Top Token Predictions")
|
114 |
+
st.plotly_chart(fig_prob)
|
115 |
+
|
116 |
+
st.markdown("π‘ *Explore more about Transformer models!*")
|