# Spaces: Runtime error
import logging | |
from functools import partial | |
from typing import Callable, List, Optional | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import umap | |
from bokeh.models import ColumnDataSource, HoverTool | |
from bokeh.palettes import Cividis256 as Pallete | |
from bokeh.plotting import Figure, figure | |
from bokeh.transform import factor_cmap | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from sklearn.manifold import TSNE | |
# Configure the root logger so INFO-level progress messages are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# One seed shared by every random step in this module (sampling, shuffling,
# t-SNE and UMAP) so repeated runs produce the same plot.
SEED: int = 0
def load_model(model_name: str) -> SentenceTransformer:
    """Build and return the ``SentenceTransformer`` named by ``model_name``."""
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in ``text`` with ``model`` and return what ``model.encode`` produces."""
    embeddings = model.encode(text)
    return embeddings
def encode_labels(labels: pd.Series) -> pd.Series:
    """Turn a label column into numbers usable for colouring points.

    - Boolean labels become 0/1 integers. pandas reports bool as a numeric
      dtype, but numpy refuses ``-`` on booleans, which the colour
      normalisation downstream relies on.
    - Other numeric labels are returned unchanged.
    - Anything else is converted to pandas category codes.
    """
    if pd.api.types.is_bool_dtype(labels):
        return labels.astype(int)
    if pd.api.types.is_numeric_dtype(labels):
        return labels
    return labels.astype("category").cat.codes
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 30,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with scikit-learn's t-SNE.

    Args:
        embeddings: One row per document.
        perplexity: Requested perplexity. It is lowered to ``len(embeddings) - 1``
            when larger, because t-SNE requires it to be below the sample count.
        n_components: Output dimensionality.
        init: Initialisation passed to ``TSNE``.
        n_iter: Maximum number of optimisation iterations.
        random_state: Seed for reproducibility.

    Returns:
        The array returned by ``TSNE.fit_transform``.
    """
    import inspect

    n_samples = len(embeddings)
    # Clamp so that small samples (e.g. a low "maximum number of documents")
    # do not make TSNE reject the perplexity.
    perplexity = min(perplexity, max(n_samples - 1, 1))
    # scikit-learn renamed ``n_iter`` to ``max_iter`` and later dropped the old
    # name; pass whichever keyword the installed version accepts.
    iter_keyword = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
    tsne = TSNE(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        random_state=random_state,
        **{iter_keyword: n_iter},
    )
    return tsne.fit_transform(embeddings)
def get_umap_embeddings(
    embeddings: np.ndarray,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with UMAP.

    The keyword parameters default to the values this function always used,
    so existing single-argument callers get identical results.

    Args:
        embeddings: One row per document.
        n_neighbors: UMAP neighbourhood size.
        min_dist: Minimum distance between embedded points.
        n_components: Output dimensionality.
        random_state: Seed for reproducibility.

    Returns:
        The array returned by ``UMAP.fit_transform``.
    """
    umap_model = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        random_state=random_state,
    )
    return umap_model.fit_transform(embeddings)
def _factor_colors(values: np.ndarray) -> dict:
    """Map each distinct value, as the string Bokeh will see, to a palette colour scaled by its numeric value."""
    # The factor strings must be produced exactly like the "label" column of the
    # data source, so derive both from ``values.astype(str)``.
    factors = values.astype(str).tolist()
    # Cast to float first: subtracting booleans raises in numpy, and small integer
    # dtypes (e.g. int8) can overflow in ``max - min``.
    numeric = np.asarray(values, dtype=float)
    min_value = numeric.min()
    value_range = numeric.max() - min_value
    if value_range == 0:
        # Every point has the same value: use a single fixed palette entry.
        color_ids = np.ones(len(numeric), dtype=int)
    else:
        # Linear rescale to palette indices 0..255.
        color_ids = np.round((numeric - min_value) / value_range * 255).astype(int)
    return {factor: Pallete[int(color_id)] for factor, color_id in zip(factors, color_ids)}


def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str,
) -> Figure:
    """Build an 800x800 Bokeh scatter plot of 2D points coloured by ``values``.

    Args:
        texts: Text shown in the hover tooltip for each point.
        xs, ys: Point coordinates.
        values: Numeric values that pick each point's colour.
        labels: Original labels shown in the hover tooltip.
        text_column, label_column: Names used as tooltip headings.

    Returns:
        The Bokeh figure.
    """
    factor_colors = _factor_colors(values)
    source = ColumnDataSource(
        data=dict(
            x=xs,
            y=ys,
            text=texts,
            label=values.astype(str).tolist(),
            original_label=labels.astype(str).tolist(),
        )
    )
    # "@text" (no "{safe}") so Bokeh HTML-escapes the document text: it comes
    # from uploaded files / hub datasets and must not be rendered as markup.
    hover = HoverTool(tooltips=[(text_column, "@text"), (label_column, "@original_label")])
    # ``width``/``height`` are accepted by Bokeh 2.x and 3.x; ``plot_width`` /
    # ``plot_height`` are not accepted by Bokeh 3.
    p = figure(width=800, height=800, tools=[hover], title="Embedding Lenses")
    # ``scatter`` draws circles by default and takes ``size`` on all versions.
    p.scatter(
        "x",
        "y",
        size=10,
        source=source,
        fill_color=factor_cmap("label", palette=list(factor_colors.values()), factors=list(factor_colors)),
    )
    return p
def uploaded_file_to_dataframe(uploaded_file: "st.uploaded_file_manager.UploadedFile") -> pd.DataFrame:
    """Read an uploaded CSV/TSV file into a DataFrame.

    The separator comes from the file-name extension (case-insensitive): tab
    for ``.tsv``, comma for anything else.
    """
    # The annotation is a string so it is not evaluated at definition time;
    # ``st.uploaded_file_manager`` may not exist in newer Streamlit releases.
    extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
    """Load a Hugging Face hub dataset split, shuffle it and return the first ``sample`` rows.

    Args:
        path: Dataset identifier passed to ``load_dataset``.
        name: Optional configuration name; ignored when empty.
        split: Split to load. It is required: without it ``load_dataset``
            returns a dict of splits, which cannot be row-sliced.
        sample: Maximum number of rows to keep.

    Raises:
        ValueError: If no split was given.
    """
    load_kwargs = {"path": path}
    if name:
        load_kwargs["name"] = name
    if split:
        load_kwargs["split"] = split
    dataset = load_dataset(**load_kwargs)
    if isinstance(dataset, dict):
        # Slicing a dict of splits would fail with an opaque TypeError; report it instead.
        raise ValueError(f"Please specify a dataset split. Available splits: {list(dataset)}")
    rows = dataset.shuffle(seed=SEED)[:sample]
    return pd.DataFrame(rows)
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed ``df[text_column]``, reduce it to 2D and return the interactive scatter plot.

    Args:
        df: Input rows. It is not modified.
        text_column: Column holding the texts to embed.
        label_column: Column used for colouring. If absent, every row gets the
            constant label 0.
        sample: If truthy, at most this many rows are randomly sampled.
        dimensionality_reduction_function: Maps the embeddings to 2D coordinates.
        model: Sentence embedding model.

    Raises:
        ValueError: If ``text_column`` is missing, or no rows remain after
            dropping missing values.
    """
    logger.info("Loading dataset in memory")
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # ``assign`` returns a copy, so the caller's DataFrame is not mutated.
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if df.empty:
        raise ValueError(
            f"No rows left after dropping missing values in columns {text_column!r} and {label_column!r}."
        )
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)

    logger.info("Embedding sentences")
    embeddings = embed_text(df[text_column].values.tolist(), model)

    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])

    logger.info("Running dimensionality reduction")
    embeddings_2d = dimensionality_reduction_function(embeddings)

    logger.info("Generating figure")
    return draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
st.title("Embedding Lenses")
st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")
uploaded_file = st.file_uploader("Choose a csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the hub")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "ag_news")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")
text_column = st.text_input("Text column name", "text")
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
model = load_model(model_name)
dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings

# An uploaded file takes precedence over the hub dataset fields.
if uploaded_file or hub_dataset:
    try:
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
        else:
            df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
        plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
    except ValueError as err:
        # Invalid user input (e.g. a missing column) is reported in the UI
        # rather than as a raw traceback; the details still go to the log.
        logger.exception("Could not generate the plot")
        st.error(str(err))
    else:
        logger.info("Displaying plot")
        st.bokeh_chart(plot)
        logger.info("Done")