borutokarma123 committed on
Commit
ffea4fb
·
verified ·
1 Parent(s): c3d8699

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app for exploring transformer models.

Shows basic architecture details, a parameter-count comparison, tokenization,
a PCA projection of token embeddings, an attention heat map and (for BLOOM)
text generation with next-token probabilities.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import streamlit as st
import torch
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline

# Approximate published parameter counts, expressed in millions.
# BLOOM is a 176-billion-parameter model, i.e. 176,000 million.
MODEL_SIZES_MILLIONS = {
    "bigscience/bloom": 176_000,
    "openai/whisper-base": 74,
    "facebook/wav2vec2-base-960h": 95,
}


@st.cache_resource
def _load_tokenizer_and_model(name):
    """Load the tokenizer and base model for *name* once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching keeps the weights from being reloaded on each rerun.
    """
    return AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name)


@st.cache_resource
def _load_causal_lm(name):
    """Load the causal-LM variant of *name* once per server process."""
    return AutoModelForCausalLM.from_pretrained(name)


# App Title
st.title("🚀 Transformer Model Explorer")
st.markdown("""
Explore different transformer models, their architectures, tokenization, and attention mechanisms.
""")

# Model Selection
model_name = st.selectbox(
    "Choose a Transformer Model:",
    ["bigscience/bloom", "openai/whisper-base", "facebook/wav2vec2-base-960h"],
)
is_bloom = "bloom" in model_name
is_whisper = "whisper" in model_name
is_wav2vec2 = "wav2vec2" in model_name

# Load Tokenizer & Model (the same Auto* classes serve every listed checkpoint)
st.write(f"Loading model: `{model_name}`...")
tokenizer, model = _load_tokenizer_and_model(model_name)

# Display Model Details
config = model.config
total_params = sum(p.numel() for p in model.parameters())
st.subheader("🛠 Model Details")
st.write(f"Model Type: `{config.model_type}`")
st.write(f"Number of Layers: `{getattr(config, 'num_hidden_layers', 'N/A')}`")
st.write(f"Number of Attention Heads: `{getattr(config, 'num_attention_heads', 'N/A')}`")
st.write(f"Total Parameters: `{total_params / 1e6:.2f}M`")

# Model Size Comparison
st.subheader("📊 Model Size Comparison")
df_size = pd.DataFrame(
    list(MODEL_SIZES_MILLIONS.items()),
    columns=["Model", "Size (Million Parameters)"],
)
# Log scale: BLOOM is more than three orders of magnitude larger than the others.
fig = px.bar(
    df_size,
    x="Model",
    y="Size (Million Parameters)",
    title="Model Size Comparison",
    log_y=True,
)
st.plotly_chart(fig)

# Tokenization Section
st.subheader("📝 Tokenization Visualization")
input_text = st.text_input("Enter Text:", "Hello, how are you?")

tokens = []
if is_whisper:
    st.write("Note: Whisper is an audio model and doesn't use text tokenization")
    st.write("Instead, it processes raw audio waveforms")
else:
    tokens = tokenizer.tokenize(input_text)
    st.write("Tokenized Output:", tokens)

# Token Embeddings Visualization (PCA projection to 2-D)
st.subheader("🧩 Token Embeddings Visualization")
inputs = None
outputs = None
if is_whisper:
    st.write("Note: Whisper uses a different embedding structure for audio features")
    st.write("Cannot directly visualize token embeddings as with text models")
elif is_wav2vec2:
    # The acoustic model's forward pass takes raw audio (`input_values`);
    # feeding it text token ids raises, so skip instead of crashing.
    st.write("Note: wav2vec2 consumes raw audio, so text token embeddings cannot be computed")
elif not tokens:
    st.write("Enter some text to visualize token embeddings.")
else:
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        # One forward pass serves both the embeddings and the attention map.
        outputs = model(**inputs, output_attentions=is_bloom)

    hidden = getattr(outputs, "last_hidden_state", None)
    if hidden is not None:
        embeddings = hidden[0].float().numpy()
        # Ensure the number of tokens and embedding rows match.
        n_tokens = min(len(tokens), embeddings.shape[0])
        embeddings = embeddings[:n_tokens]
        tokens = tokens[:n_tokens]
        if n_tokens < 2:
            # PCA cannot extract two components from fewer than two samples.
            st.write("At least two tokens are needed for a 2-D PCA projection.")
        else:
            reduced_embeddings = PCA(n_components=2).fit_transform(embeddings)
            df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
            df_embeddings["Token"] = tokens
            fig = px.scatter(
                df_embeddings,
                x="PCA1",
                y="PCA2",
                text="Token",
                title="Token Embeddings (PCA Projection)",
            )
            st.plotly_chart(fig)

# Attention Visualization (BLOOM only)
attentions = getattr(outputs, "attentions", None) if outputs is not None else None
if is_bloom and attentions:
    st.subheader("🔍 Attention Map")
    # Last layer, first batch item, first head. Explicit indexing (rather than
    # squeeze) keeps the 2-D shape even for a single-token input.
    attention = attentions[-1][0, 0].float().numpy()
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(attention, cmap="viridis", xticklabels=tokens, yticklabels=tokens, ax=ax)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit reruns would otherwise accumulate them

# Text Generation Demo (for BLOOM)
if is_bloom:
    st.subheader("✍️ Text Generation & Token Probabilities")
    if inputs is None:
        st.write("Enter some text to generate a continuation.")
    else:
        # Reuse one cached causal LM for both the pipeline and the logits
        # instead of loading the checkpoint a second and third time.
        model_gen = _load_causal_lm(model_name)
        generator = pipeline(
            "text-generation",
            model=model_gen,
            tokenizer=tokenizer,
            return_full_text=False,
        )
        # No `return_tensors=True`: that makes the pipeline return
        # "generated_token_ids" instead of the "generated_text" read below.
        generated_output = generator(input_text, max_length=50)
        st.write("Generated Output:", generated_output[0]["generated_text"])

        # Token Probability Visualization
        with torch.no_grad():
            logits = model_gen(**inputs).logits[:, -1, :]
        probs = torch.nn.functional.softmax(logits.float(), dim=-1)[0].numpy()
        top_ids = np.argsort(probs)[-10:][::-1]  # Top 10 tokens, most likely first
        # Build from parallel lists so two ids that decode to the same text
        # are both kept (a dict keyed by text would drop one).
        df_probs = pd.DataFrame({
            "Token": [tokenizer.decode([int(idx)]) for idx in top_ids],
            "Probability": [float(probs[idx]) for idx in top_ids],
        })
        fig_prob = px.bar(df_probs, x="Token", y="Probability", title="Top Token Predictions")
        st.plotly_chart(fig_prob)

st.markdown("💡 *Explore more about Transformer models!*")