import gc
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import umap
from matplotlib.figure import Figure
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


class EmbeddingVisualizer:
    """Load a Hugging Face model and visualize mean-pooled embeddings of short texts."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name):
        """Load ``model_name`` (replacing any current model) and return a status string.

        The previous model is released *before* the new one is loaded so that the
        two never have to be resident at the same time.
        """
        token = os.environ.get("HF_TOKEN")
        if self.model is not None:
            # empty_cache() can only hand back memory that is no longer referenced,
            # so drop the old model (and collect any cycles) first.
            self.model = None
            self.tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        if "gemma" in model_name:
            # Half precision only on GPU; float16 inference on CPU is slow and some
            # ops lack CPU float16 kernels in older PyTorch builds.
            dtype = torch.float16 if self.device.type == "cuda" else torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                model_name, token=token, torch_dtype=dtype
            )
        else:
            model = AutoModel.from_pretrained(model_name, token=token)
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        return f"Loaded model: {model_name}"

    def _is_ready(self):
        """Return True when both a model and its tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None

    @staticmethod
    def _non_blank(values):
        """Return the entries of ``values`` that are non-empty after stripping whitespace."""
        return [v for v in values if v and v.strip()]

    def _embed_with_labels(self, texts):
        """Embed each text exactly once; return (embeddings, labels) for those that produced one."""
        embeddings, labels = [], []
        for text in texts:
            emb = self.get_embedding(text)
            if emb is not None:
                embeddings.append(emb)
                labels.append(text)
        return embeddings, labels

    def get_embedding(self, text):
        """Return the attention-masked mean of the last hidden layer for ``text``.

        Returns ``None`` for empty/blank text; otherwise a NumPy array produced by
        squeezing the pooled (1, hidden) tensor.
        """
        if not text or not text.strip():
            return None
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        # Zero out padding positions, then divide by the number of real tokens.
        mask = inputs["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * mask, dim=1)
        sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        return (sum_embeddings / sum_mask).squeeze().float().cpu().numpy()

    def calculate_similarity_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix of ``embeddings``, or None if empty."""
        if not embeddings:
            return None
        return cosine_similarity(np.array(embeddings))

    def reduce_dimensionality(self, embeddings, n_components, method):
        """Project ``embeddings`` (rows = samples) onto exactly ``n_components`` columns.

        When there are too few samples for the requested dimensionality, the reducer
        is fit with fewer components (k < number of samples) and the result is padded
        with zero columns, so callers can always index columns 0..n_components-1.

        Raises:
            ValueError: if ``method`` is neither "pca" nor "umap".
        """
        if method not in ("pca", "umap"):
            raise ValueError("Invalid dimensionality reduction method")

        # Convert to dense array if sparse
        if hasattr(embeddings, "toarray"):
            embeddings = embeddings.toarray()
        embeddings = np.asarray(embeddings)
        n_samples, n_features = embeddings.shape

        if n_samples == 1:
            # A single point has no spread to project; place it at the origin.
            return np.zeros((1, n_components))

        # Ensure k < N (and never more components than features).
        k = min(n_components, n_samples - 1, n_features)

        if method == "pca" or n_samples < 4:
            # For very small datasets, fall back to PCA.
            reducer = PCA(n_components=k)
        else:
            # Adjust parameters based on data size.
            n_neighbors = min(15, n_samples - 1)  # Ensure n_neighbors < n_samples
            min_dist = 0.1 if n_samples > 4 else 0.5  # Larger min_dist for small datasets
            reducer = umap.UMAP(
                n_components=k,
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                metric="euclidean",
                random_state=42,
            )

        reduced = reducer.fit_transform(embeddings)
        missing = n_components - reduced.shape[1]
        if missing > 0:
            pad = np.zeros((n_samples, missing), dtype=reduced.dtype)
            reduced = np.hstack([reduced, pad])
        return reduced

    def visualize_embeddings(self, model_choice, is_3d, word1, word2, word3, word4, word5,
                             word6, word7, word8, positive_word1, positive_word2,
                             negative_word1, negative_word2, dim_reduction_method):
        """Return a 2D/3D Plotly scatter of the words, plus an optional arithmetic vector.

        The arithmetic vector is sum(positive embeddings) - sum(negative embeddings).
        Returns None when no model is loaded or no non-blank input was given.
        """
        if not self._is_ready():
            return None

        words = self._non_blank([word1, word2, word3, word4, word5, word6, word7, word8])
        positive_words = self._non_blank([positive_word1, positive_word2])
        negative_words = self._non_blank([negative_word1, negative_word2])

        embeddings, labels = self._embed_with_labels(words)

        pos_embs, _ = self._embed_with_labels(positive_words)
        neg_embs, _ = self._embed_with_labels(negative_words)
        if pos_embs or neg_embs:
            # sum([]) == 0, so a missing side simply contributes nothing.
            embeddings.append(sum(pos_embs) - sum(neg_embs))
            labels.append("Arithmetic Result")

        if not embeddings:
            return None

        n_dims = 3 if is_3d else 2
        reduced = self.reduce_dimensionality(np.array(embeddings), n_dims, dim_reduction_method)
        title = (
            f"{n_dims}D Word Embeddings Visualization ({model_choice}) - "
            f"{dim_reduction_method.upper()}"
        )
        if is_3d:
            fig = px.scatter_3d(x=reduced[:, 0], y=reduced[:, 1], z=reduced[:, 2],
                                text=labels, title=title)
        else:
            fig = px.scatter(x=reduced[:, 0], y=reduced[:, 1], text=labels, title=title)
        fig.update_traces(textposition="top center")
        return fig

    def visualize_similarity_heatmap(self, model_choice, word1, word2, word3, word4,
                                     word5, word6, word7, word8):
        """Return a matplotlib Figure of the words' pairwise cosine similarities, or None."""
        if not self._is_ready():
            return None

        words = self._non_blank([word1, word2, word3, word4, word5, word6, word7, word8])
        embeddings, labels = self._embed_with_labels(words)
        if not embeddings:
            return None
        similarity_matrix = self.calculate_similarity_matrix(embeddings)
        if similarity_matrix is None:
            return None

        # Use a standalone Figure rather than pyplot: pyplot keeps every figure it
        # creates alive until closed, and its "current figure/axes" state is global,
        # which is unsafe if Gradio runs handlers concurrently.
        fig = Figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        cax = ax.matshow(similarity_matrix, interpolation="nearest")
        fig.colorbar(cax)
        ticks = np.arange(len(labels))
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(labels, rotation=45, ha="left")
        ax.set_yticklabels(labels)
        ax.set_title(f"Cosine Similarity Heatmap ({model_choice})")
        return fig


# Initialize the visualizer
visualizer = EmbeddingVisualizer()

# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Word Embedding Visualization")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=["google/gemma-2b", "bert-large-uncased"],
                value="google/gemma-2b",
                label="Select Model",
            )
            load_status = gr.Textbox(label="Model Status", interactive=False)
            is_3d = gr.Checkbox(label="Use 3D Visualization", value=False)
            dim_reduction_method = gr.Radio(
                choices=["pca", "umap"],
                value="pca",
                label="Dimensionality Reduction Method",
            )
        with gr.Column():
            word1 = gr.Textbox(label="Word 1")
            word2 = gr.Textbox(label="Word 2")
            word3 = gr.Textbox(label="Word 3")
            word4 = gr.Textbox(label="Word 4")
            word5 = gr.Textbox(label="Word 5")
            word6 = gr.Textbox(label="Word 6")
            word7 = gr.Textbox(label="Word 7")
            word8 = gr.Textbox(label="Word 8")
        with gr.Column():
            positive_word1 = gr.Textbox(label="Positive Word 1")
            positive_word2 = gr.Textbox(label="Positive Word 2")
            negative_word1 = gr.Textbox(label="Negative Word 1")
            negative_word2 = gr.Textbox(label="Negative Word 2")

    with gr.Tabs():
        with gr.Tab("Scatter Plot"):
            plot_output = gr.Plot()
        with gr.Tab("Similarity Heatmap"):
            heatmap_output = gr.Plot()

    inputs = [
        model_choice, is_3d,
        word1, word2, word3, word4, word5, word6, word7, word8,
        positive_word1, positive_word2, negative_word1, negative_word2,
        dim_reduction_method,
    ]
    similarity_inputs = [model_choice, word1, word2, word3, word4, word5, word6, word7, word8]

    # Load model when selected, then redraw both plots. Chaining with .then() makes
    # the redraw wait for the load instead of racing it against the old model.
    model_choice.change(
        fn=visualizer.load_model, inputs=[model_choice], outputs=[load_status]
    ).then(
        fn=visualizer.visualize_embeddings, inputs=inputs, outputs=[plot_output]
    ).then(
        fn=visualizer.visualize_similarity_heatmap,
        inputs=similarity_inputs,
        outputs=[heatmap_output],
    )

    # Update visualizations when any other input changes (model_choice is
    # handled by the chain above).
    for input_component in inputs[1:]:
        input_component.change(
            fn=visualizer.visualize_embeddings, inputs=inputs, outputs=[plot_output]
        )
    for input_component in similarity_inputs[1:]:
        input_component.change(
            fn=visualizer.visualize_similarity_heatmap,
            inputs=similarity_inputs,
            outputs=[heatmap_output],
        )

    # Add Clear All button
    clear_button = gr.Button("Clear All")
    text_inputs = [
        word1, word2, word3, word4, word5, word6, word7, word8,
        positive_word1, positive_word2, negative_word1, negative_word2,
    ]

    def clear_all():
        """Return one empty string per text input component."""
        return [""] * len(text_inputs)

    clear_button.click(fn=clear_all, inputs=[], outputs=text_inputs)

if __name__ == "__main__":
    # Load initial model
    visualizer.load_model("google/gemma-2b")
    iface.launch()