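"""Gradio Space for exploring word embeddings.

Loads a Hugging Face model (Gemma or BERT), mean-pools the last hidden
states into one vector per word, and visualizes the results as a 2D/3D
scatter plot (PCA or UMAP) and a cosine-similarity heatmap. Also supports
simple embedding arithmetic (sum of positive words minus sum of negative
words).
"""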
import os
import gradio as gr
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import plotly.express as px
from sklearn.metrics.pairwise import cosine_similarity
import umap
class EmbeddingVisualizer:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def load_model(self, model_name):
        if self.model is not None:
            # Release the previous model and clear the CUDA cache if using GPU
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN"))
        if "gemma" in model_name:
            # Gemma is a gated causal LM; load it in float16 to halve memory use
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, token=os.environ.get("HF_TOKEN"), torch_dtype=torch.float16
            )
        else:
            self.model = AutoModel.from_pretrained(model_name)
        self.model = self.model.to(self.device)
        return f"Loaded model: {model_name}"
    def get_embedding(self, text):
        if not text.strip():
            return None
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        # Masked mean pooling: average token states, ignoring padding positions
        mask = inputs["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
        masked_embeddings = hidden_states * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
        return embedding
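    # The pooling above is standard masked mean pooling: for token states
    # h_1..h_T and attention mask m_1..m_T it computes
    #     e = (sum_t m_t * h_t) / max(sum_t m_t, 1e-9).
    # A minimal standalone sketch of the same step, on made-up tensors:
    #     h = torch.randn(1, 4, 8)              # (batch, tokens, hidden)
    #     m = torch.tensor([[1., 1., 1., 0.]])  # final token is padding
    #     e = (h * m.unsqueeze(-1)).sum(1) / m.sum(1, keepdim=True).clamp(min=1e-9)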
    def calculate_similarity_matrix(self, embeddings):
        if not embeddings:
            return None
        embeddings_np = np.array(embeddings)
        return cosine_similarity(embeddings_np)
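    # Cosine similarity of two vectors a and b is (a . b) / (||a|| * ||b||),
    # so the matrix returned here is symmetric with 1.0 on the diagonal.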
    def reduce_dimensionality(self, embeddings, n_components, method):
        n_samples = embeddings.shape[0]
        # A single sample carries no variance to reduce; plot it at the origin
        if n_samples == 1:
            return np.zeros((1, n_components))
        requested = n_components
        n_components = min(n_components, n_samples - 1)  # ensure k < N
        if method == "pca":
            reducer = PCA(n_components=n_components)
        elif method == "umap":
            # UMAP needs a few points to build a neighbor graph;
            # fall back to PCA for very small datasets
            if n_samples < 4:
                reducer = PCA(n_components=n_components)
            else:
                # Adjust parameters based on data size
                n_neighbors = min(15, n_samples - 1)  # ensure n_neighbors < n_samples
                min_dist = 0.1 if n_samples > 4 else 0.5  # spread out small datasets
                reducer = umap.UMAP(
                    n_components=n_components,
                    n_neighbors=n_neighbors,
                    min_dist=min_dist,
                    metric='euclidean',
                    random_state=42
                )
        else:
            raise ValueError("Invalid dimensionality reduction method")
        # Convert to a dense array if the input is sparse
        if hasattr(embeddings, 'toarray'):
            embeddings = embeddings.toarray()
        reduced = reducer.fit_transform(embeddings)
        # If fewer components were possible than requested, pad with zero
        # columns so callers can always index the full number of plot axes
        if reduced.shape[1] < requested:
            reduced = np.hstack([reduced, np.zeros((n_samples, requested - reduced.shape[1]))])
        return reduced
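    # Choice of reducer: PCA is deterministic and keeps the directions of
    # maximal variance, while UMAP better preserves local neighborhoods but
    # needs more points to be meaningful, hence the PCA fallback above.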
    def visualize_embeddings(self, model_choice, is_3d,
                             word1, word2, word3, word4, word5, word6, word7, word8,
                             positive_word1, positive_word2,
                             negative_word1, negative_word2,
                             dim_reduction_method):
        words = [word1, word2, word3, word4, word5, word6, word7, word8]
        words = [w for w in words if w.strip()]
        positive_words = [w for w in [positive_word1, positive_word2] if w.strip()]
        negative_words = [w for w in [negative_word1, negative_word2] if w.strip()]
        embeddings = []
        labels = []
        for word in words:
            emb = self.get_embedding(word)
            if emb is not None:
                embeddings.append(emb)
                labels.append(word)
        if positive_words or negative_words:
            # Embed each word once and drop any empty results
            pos_embs = [e for e in (self.get_embedding(w) for w in positive_words) if e is not None]
            neg_embs = [e for e in (self.get_embedding(w) for w in negative_words) if e is not None]
            if pos_embs or neg_embs:
                pos_sum = sum(pos_embs) if pos_embs else 0
                neg_sum = sum(neg_embs) if neg_embs else 0
                arithmetic_emb = pos_sum - neg_sum
                embeddings.append(arithmetic_emb)
                labels.append("Arithmetic Result")
        if not embeddings:
            return None
        embeddings = np.array(embeddings)
        # Reduce to 2 or 3 dimensions for plotting
        if is_3d:
            embeddings_reduced = self.reduce_dimensionality(embeddings, 3, dim_reduction_method)
            fig = px.scatter_3d(x=embeddings_reduced[:, 0],
                                y=embeddings_reduced[:, 1],
                                z=embeddings_reduced[:, 2],
                                text=labels,
                                title=f"3D Word Embeddings Visualization ({model_choice}) - {dim_reduction_method.upper()}")
        else:
            embeddings_reduced = self.reduce_dimensionality(embeddings, 2, dim_reduction_method)
            fig = px.scatter(x=embeddings_reduced[:, 0],
                             y=embeddings_reduced[:, 1],
                             text=labels,
                             title=f"2D Word Embeddings Visualization ({model_choice}) - {dim_reduction_method.upper()}")
        fig.update_traces(textposition='top center')
        return fig
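    # The "Arithmetic Result" point implements the classic embedding analogy
    # trick: e.g. positives ["king", "woman"] and negative ["man"] plot
    # emb(king) + emb(woman) - emb(man), which for well-trained word
    # embeddings famously lands near emb(queen).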
    def visualize_similarity_heatmap(self, model_choice,
                                     word1, word2, word3, word4, word5, word6, word7, word8):
        words = [word1, word2, word3, word4, word5, word6, word7, word8]
        words = [w for w in words if w.strip()]
        # Embed each word once, keeping labels aligned with their embeddings
        pairs = [(w, self.get_embedding(w)) for w in words]
        pairs = [(w, e) for w, e in pairs if e is not None]
        if not pairs:
            return None
        words = [w for w, _ in pairs]
        embeddings = [e for _, e in pairs]
        similarity_matrix = self.calculate_similarity_matrix(embeddings)
        if similarity_matrix is None:
            return None
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        cax = ax.matshow(similarity_matrix, interpolation='nearest')
        fig.colorbar(cax)
        ax.set_xticks(np.arange(len(words)))
        ax.set_yticks(np.arange(len(words)))
        ax.set_xticklabels(words, rotation=45, ha='left')
        ax.set_yticklabels(words)
        ax.set_title(f"Cosine Similarity Heatmap ({model_choice})")
        return fig
# Initialize the visualizer
visualizer = EmbeddingVisualizer()

# Create the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Word Embedding Visualization")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                choices=["google/gemma-2b", "bert-large-uncased"],
                value="google/gemma-2b",
                label="Select Model"
            )
            load_status = gr.Textbox(label="Model Status", interactive=False)
            is_3d = gr.Checkbox(label="Use 3D Visualization", value=False)
            dim_reduction_method = gr.Radio(
                choices=["pca", "umap"],
                value="pca",
                label="Dimensionality Reduction Method"
            )
        with gr.Column():
            word1 = gr.Textbox(label="Word 1")
            word2 = gr.Textbox(label="Word 2")
            word3 = gr.Textbox(label="Word 3")
            word4 = gr.Textbox(label="Word 4")
            word5 = gr.Textbox(label="Word 5")
            word6 = gr.Textbox(label="Word 6")
            word7 = gr.Textbox(label="Word 7")
            word8 = gr.Textbox(label="Word 8")
        with gr.Column():
            positive_word1 = gr.Textbox(label="Positive Word 1")
            positive_word2 = gr.Textbox(label="Positive Word 2")
            negative_word1 = gr.Textbox(label="Negative Word 1")
            negative_word2 = gr.Textbox(label="Negative Word 2")
    with gr.Tabs():
        with gr.Tab("Scatter Plot"):
            plot_output = gr.Plot()
        with gr.Tab("Similarity Heatmap"):
            heatmap_output = gr.Plot()

    # Load the model whenever a new one is selected
    model_choice.change(
        fn=visualizer.load_model,
        inputs=[model_choice],
        outputs=[load_status]
    )

    # Redraw the scatter plot when any relevant input changes
    inputs = [
        model_choice, is_3d,
        word1, word2, word3, word4, word5, word6, word7, word8,
        positive_word1, positive_word2,
        negative_word1, negative_word2,
        dim_reduction_method
    ]
    for input_component in inputs:
        input_component.change(
            fn=visualizer.visualize_embeddings,
            inputs=inputs,
            outputs=[plot_output]
        )

    # Redraw the heatmap when the model or any word changes
    similarity_inputs = [model_choice,
                         word1, word2, word3, word4, word5, word6, word7, word8]
    for input_component in similarity_inputs:
        input_component.change(
            fn=visualizer.visualize_similarity_heatmap,
            inputs=similarity_inputs,
            outputs=[heatmap_output]
        )

    # Clear All button resets every text input
    clear_button = gr.Button("Clear All")

    def clear_all():
        return [""] * 12  # empty strings for the 12 text input components

    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[word1, word2, word3, word4, word5, word6, word7, word8,
                 positive_word1, positive_word2,
                 negative_word1, negative_word2]
    )

if __name__ == "__main__":
    # Load the initial model before launching the app
    visualizer.load_model("google/gemma-2b")
    iface.launch()
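# To run this Space locally, the imports imply roughly the following
# dependencies (package names assumed; pin versions as needed):
#     pip install gradio torch transformers scikit-learn umap-learn plotly matplotlib numpy
# Gated checkpoints such as google/gemma-2b additionally need an HF_TOKEN
# environment variable with access granted to the model.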