barttee committed on
Commit
c4eecf3
·
verified ·
1 Parent(s): 73a08d6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +246 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
5
+ import matplotlib.pyplot as plt
6
+ from sklearn.decomposition import PCA
7
+ import numpy as np
8
+ import plotly.express as px
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+ import umap
11
+ import pandas as pd
12
+
13
class EmbeddingVisualizer:
    """Embed short texts with a Hugging Face model and plot the results.

    The instance holds at most one model/tokenizer pair at a time (see
    ``load_model``). Embeddings are mean-pooled last-layer hidden states.
    They can be shown as a 2D/3D scatter plot after PCA or UMAP reduction,
    or as a cosine-similarity heatmap.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name):
        """Load ``model_name`` and its tokenizer, replacing any loaded model.

        Args:
            model_name: Hugging Face model identifier. Names containing
                "gemma" are loaded as a causal LM in float16 using the
                ``HF_TOKEN`` environment variable; everything else is
                loaded with ``AutoModel``.

        Returns:
            A human-readable status string.
        """
        import gc  # local import: only needed when swapping models

        if self.model is not None:
            # Release the previous model *before* clearing the CUDA cache:
            # empty_cache() can only hand back memory that no live tensor
            # still occupies.
            self.model = None
            self.tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        hf_token = os.environ.get("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        if "gemma" in model_name:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, token=hf_token, torch_dtype=torch.float16
            )
        else:
            self.model = AutoModel.from_pretrained(model_name)
        self.model = self.model.to(self.device)
        return f"Loaded model: {model_name}"

    @staticmethod
    def _non_blank(words):
        """Return the entries of ``words`` that are neither None nor blank."""
        return [w for w in words if w and w.strip()]

    @staticmethod
    def _mean_pool(hidden_states, attention_mask):
        """Average ``hidden_states`` over dim 1, ignoring masked-out positions."""
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = torch.sum(hidden_states * mask, dim=1)
        # Clamp so an all-zero mask cannot divide by zero.
        counts = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        return summed / counts

    @staticmethod
    def _pad_components(reduced, n_components):
        """Right-pad ``reduced`` with zero columns up to ``n_components`` columns."""
        reduced = np.asarray(reduced)
        missing = n_components - reduced.shape[1]
        if missing <= 0:
            return reduced
        filler = np.zeros((reduced.shape[0], missing), dtype=reduced.dtype)
        return np.hstack([reduced, filler])

    def _embed_all(self, words):
        """Embed each word once; return parallel lists (labels, embeddings)."""
        labels = []
        embeddings = []
        for word in words:
            emb = self.get_embedding(word)
            if emb is not None:
                labels.append(word)
                embeddings.append(emb)
        return labels, embeddings

    def get_embedding(self, text):
        """Return the mean-pooled last-layer embedding of ``text``.

        Args:
            text: Text to embed.

        Returns:
            A NumPy array, or ``None`` when ``text`` is None or blank.

        Raises:
            RuntimeError: If no model/tokenizer has been loaded yet.
        """
        if not text or not text.strip():
            return None
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("No model loaded; call load_model() first")

        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        pooled = self._mean_pool(hidden_states, inputs["attention_mask"])
        return pooled.squeeze().cpu().numpy()

    def calculate_similarity_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix, or None if empty."""
        # len() instead of truthiness so an ndarray argument is accepted too.
        if embeddings is None or len(embeddings) == 0:
            return None
        return cosine_similarity(np.asarray(embeddings))

    def reduce_dimensionality(self, embeddings, n_components, method):
        """Project ``embeddings`` down to exactly ``n_components`` columns.

        Args:
            embeddings: 2D array-like with one row per sample (objects with
                a ``toarray`` method are densified first).
            n_components: Number of output columns the caller will index.
            method: ``"pca"`` or ``"umap"``. UMAP falls back to PCA when
                there are fewer than 4 samples.

        Returns:
            Array of shape ``(n_samples, n_components)``. When there are too
            few samples to fit that many components, the trailing columns
            are zero-filled so callers can always index every axis.

        Raises:
            ValueError: If ``method`` is not recognised (checked only when
                there is more than one sample).
        """
        # Convert to dense array if sparse
        if hasattr(embeddings, 'toarray'):
            embeddings = embeddings.toarray()

        n_samples = embeddings.shape[0]

        # A single sample cannot be projected; place it at the origin.
        if n_samples == 1:
            return np.zeros((1, n_components))

        fit_components = min(n_components, n_samples - 1)  # Ensure k < N

        if method == "pca":
            reducer = PCA(n_components=fit_components)
        elif method == "umap":
            # For very small datasets, fall back to PCA
            if n_samples < 4:
                reducer = PCA(n_components=fit_components)
            else:
                # Adjust parameters based on data size
                n_neighbors = min(15, n_samples - 1)  # Ensure n_neighbors < n_samples
                min_dist = 0.1 if n_samples > 4 else 0.5  # Increase min_dist for small datasets
                reducer = umap.UMAP(
                    n_components=fit_components,
                    n_neighbors=n_neighbors,
                    min_dist=min_dist,
                    metric='euclidean',
                    random_state=42
                )
        else:
            raise ValueError("Invalid dimensionality reduction method")

        reduced = reducer.fit_transform(embeddings)
        # fit_components may be smaller than requested; pad so that
        # column 1 (and column 2 in 3D) always exists for plotting.
        return self._pad_components(reduced, n_components)

    def visualize_embeddings(self, model_choice, is_3d,
                             word1, word2, word3, word4, word5, word6, word7, word8,
                             positive_word1, positive_word2,
                             negative_word1, negative_word2,
                             dim_reduction_method):
        """Build a 2D or 3D scatter plot of the embeddings of the given words.

        If any positive/negative word is given, an extra point labelled
        "Arithmetic Result" is added at sum(positive) - sum(negative).

        Returns:
            A Plotly figure, or ``None`` when there is nothing to plot.
        """
        words = self._non_blank([word1, word2, word3, word4, word5, word6, word7, word8])
        positive_words = self._non_blank([positive_word1, positive_word2])
        negative_words = self._non_blank([negative_word1, negative_word2])

        labels, embeddings = self._embed_all(words)

        # Each arithmetic word is embedded exactly once.
        _, pos_embs = self._embed_all(positive_words)
        _, neg_embs = self._embed_all(negative_words)
        if pos_embs or neg_embs:
            pos_sum = sum(pos_embs) if pos_embs else 0
            neg_sum = sum(neg_embs) if neg_embs else 0
            embeddings.append(pos_sum - neg_sum)
            labels.append("Arithmetic Result")

        if not embeddings:
            return None
        embeddings = np.array(embeddings)

        # Reduce dimensionality
        n_dims = 3 if is_3d else 2
        embeddings_reduced = self.reduce_dimensionality(embeddings, n_dims, dim_reduction_method)
        title = f"{n_dims}D Word Embeddings Visualization ({model_choice}) - {dim_reduction_method.upper()}"
        if is_3d:
            fig = px.scatter_3d(x=embeddings_reduced[:, 0],
                                y=embeddings_reduced[:, 1],
                                z=embeddings_reduced[:, 2],
                                text=labels,
                                title=title)
        else:
            fig = px.scatter(x=embeddings_reduced[:, 0],
                             y=embeddings_reduced[:, 1],
                             text=labels,
                             title=title)
        fig.update_traces(textposition='top center')
        return fig

    def visualize_similarity_heatmap(self, model_choice,
                                     word1, word2, word3, word4, word5, word6, word7, word8):
        """Build a cosine-similarity heatmap for the given words.

        Returns:
            A Matplotlib figure, or ``None`` when no word yields an embedding.
        """
        words = self._non_blank([word1, word2, word3, word4, word5, word6, word7, word8])
        # Labels are collected alongside the embeddings so the tick labels
        # always line up with the matrix rows/columns.
        labels, embeddings = self._embed_all(words)
        if not embeddings:
            return None
        similarity_matrix = self.calculate_similarity_matrix(embeddings)
        if similarity_matrix is None:
            return None

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        cax = ax.matshow(similarity_matrix, interpolation='nearest')
        fig.colorbar(cax)
        ax.set_xticks(np.arange(len(labels)))
        ax.set_yticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha='left')
        ax.set_yticklabels(labels)
        ax.set_title(f"Cosine Similarity Heatmap ({model_choice})")
        # Unregister the figure from pyplot so repeated calls do not pile up
        # open figures; the returned Figure object remains usable.
        plt.close(fig)
        return fig
159
+
160
# Initialize the visualizer (no model is loaded until load_model is called)
visualizer = EmbeddingVisualizer()

# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Word Embedding Visualization")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=["google/gemma-2b", "bert-large-uncased"],
                value="google/gemma-2b",
                label="Select Model"
            )
            load_status = gr.Textbox(label="Model Status", interactive=False)
            is_3d = gr.Checkbox(label="Use 3D Visualization", value=False)
            dim_reduction_method = gr.Radio(
                choices=["pca", "umap"],
                value="pca",
                label="Dimensionality Reduction Method"
            )
        with gr.Column():
            word1 = gr.Textbox(label="Word 1")
            word2 = gr.Textbox(label="Word 2")
            word3 = gr.Textbox(label="Word 3")
            word4 = gr.Textbox(label="Word 4")
            word5 = gr.Textbox(label="Word 5")
            word6 = gr.Textbox(label="Word 6")
            word7 = gr.Textbox(label="Word 7")
            word8 = gr.Textbox(label="Word 8")
        with gr.Column():
            positive_word1 = gr.Textbox(label="Positive Word 1")
            positive_word2 = gr.Textbox(label="Positive Word 2")
            negative_word1 = gr.Textbox(label="Negative Word 1")
            negative_word2 = gr.Textbox(label="Negative Word 2")
    with gr.Tabs():
        with gr.Tab("Scatter Plot"):
            plot_output = gr.Plot()
        with gr.Tab("Similarity Heatmap"):
            heatmap_output = gr.Plot()

    # Inputs of the scatter plot, in the positional order expected by
    # visualizer.visualize_embeddings.
    inputs = [
        model_choice, is_3d,
        word1, word2, word3, word4, word5, word6, word7, word8,
        positive_word1, positive_word2,
        negative_word1, negative_word2,
        dim_reduction_method
    ]
    # Inputs of the heatmap, in the positional order expected by
    # visualizer.visualize_similarity_heatmap.
    similarity_inputs = [model_choice,
                         word1, word2, word3, word4, word5, word6, word7, word8]

    # Load model when selected, and only then redraw both plots. Chaining
    # with .then() keeps the redraws from running before (or while) the
    # model is being swapped, which independent change-listeners on
    # model_choice would allow.
    model_choice.change(
        fn=visualizer.load_model,
        inputs=[model_choice],
        outputs=[load_status]
    ).then(
        fn=visualizer.visualize_embeddings,
        inputs=inputs,
        outputs=[plot_output]
    ).then(
        fn=visualizer.visualize_similarity_heatmap,
        inputs=similarity_inputs,
        outputs=[heatmap_output]
    )

    # Update visualization when any other input changes
    for input_component in inputs:
        if input_component is model_choice:
            continue  # handled by the chained listener above
        input_component.change(
            fn=visualizer.visualize_embeddings,
            inputs=inputs,
            outputs=[plot_output]
        )
    for input_component in similarity_inputs:
        if input_component is model_choice:
            continue  # handled by the chained listener above
        input_component.change(
            fn=visualizer.visualize_similarity_heatmap,
            inputs=similarity_inputs,
            outputs=[heatmap_output]
        )

    # Add Clear All button
    clear_button = gr.Button("Clear All")

    # Every free-text box that "Clear All" resets.
    text_inputs = [word1, word2, word3, word4, word5, word6, word7, word8,
                   positive_word1, positive_word2,
                   negative_word1, negative_word2]

    def clear_all():
        """Return one empty string per text box in ``text_inputs``."""
        return [""] * len(text_inputs)

    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=text_inputs
    )

if __name__ == "__main__":
    # Load initial model (matches the dropdown's default value)
    visualizer.load_model("google/gemma-2b")
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ matplotlib
5
+ scikit-learn
6
+ numpy
7
+ plotly
8
+ umap-learn
9
+ pandas