Spaces:
Runtime error
Runtime error
File size: 6,063 Bytes
5aad559 bcd62f3 a9d1447 5aad559 9735252 5aad559 a9d1447 5aad559 bcd62f3 5aad559 d915150 a9d1447 9735252 5aad559 737452a 5aad559 9735252 5aad559 a9d1447 5aad559 262af65 5aad559 bcd62f3 a9d1447 bcd62f3 abd3459 bcd62f3 9735252 bcd62f3 9735252 a9d1447 5aad559 262af65 a9d1447 5aad559 737452a 5aad559 9735252 5aad559 9735252 5aad559 b40f734 5aad559 bcd62f3 b846822 bcd62f3 5aad559 9735252 5aad559 bcd62f3 a9d1447 bcd62f3 5aad559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import logging
from functools import partial
from typing import Callable, List, Optional
import numpy as np
import pandas as pd
import streamlit as st
import umap
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
# Configure root logging so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used for the progress messages in this file.
logger = logging.getLogger(__name__)
# Shared random seed: passed to t-SNE/UMAP, DataFrame sampling and dataset
# shuffling in this file so repeated runs use the same seed.
SEED = 0
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-embedding model identified by ``model_name``.

    The function is wrapped in ``st.cache`` with the spinner hidden and output
    mutation allowed.

    Args:
        model_name: Identifier handed directly to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode a list of strings with the given sentence-embedding model.

    Args:
        text: Strings to embed.
        model: Object whose ``encode`` method produces the embeddings.

    Returns:
        Whatever ``model.encode(text)`` returns (annotated as ``np.ndarray``).
    """
    sentence_vectors = model.encode(text)
    return sentence_vectors
def encode_labels(labels: pd.Series) -> pd.Series:
    """Convert a label column into numbers that support arithmetic.

    Args:
        labels: Label column; may be numeric, boolean or categorical/text.

    Returns:
        * boolean labels cast to integers (``True`` -> 1, ``False`` -> 0);
        * other numeric labels unchanged (the same Series object);
        * anything else replaced by its integer category codes.
    """
    if pd.api.types.is_bool_dtype(labels):
        # pandas reports bool as a numeric dtype, but NumPy refuses to subtract
        # boolean values, so a bool Series would break any later
        # ``max - min`` range computation. Cast to integers up front.
        return labels.astype(int)
    if pd.api.types.is_numeric_dtype(labels):
        return labels
    return labels.astype("category").cat.codes
def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Reduce ``embeddings`` with scikit-learn's ``TSNE``.

    Args:
        embeddings: Input vectors handed to ``TSNE.fit_transform``.
        perplexity: Forwarded to ``TSNE``.
        n_components: Forwarded to ``TSNE``; number of output dimensions.
        init: Forwarded to ``TSNE``.
        n_iter: Forwarded to ``TSNE``.
        random_state: Forwarded to ``TSNE``; defaults to the module ``SEED``.

    Returns:
        The result of ``TSNE.fit_transform`` on ``embeddings``.
    """
    tsne_options = dict(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    reducer = TSNE(**tsne_options)
    projected = reducer.fit_transform(embeddings)
    return projected
def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """Reduce ``embeddings`` to two components with UMAP.

    The reducer is built with ``n_neighbors=15``, ``min_dist=0.1``,
    ``n_components=2`` and the module-level ``SEED`` as ``random_state``.

    Args:
        embeddings: Input vectors handed to ``UMAP.fit_transform``.

    Returns:
        The result of ``UMAP.fit_transform`` on ``embeddings``.
    """
    reducer = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        n_components=2,
        random_state=SEED,
    )
    projected = reducer.fit_transform(embeddings)
    return projected
def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot whose points are coloured by ``values``.

    Args:
        texts: Text shown in the hover tooltip for each point.
        xs: X coordinate of each point.
        ys: Y coordinate of each point.
        values: Numeric value per point; scaled to a palette index (0-255).
        labels: Original label per point, shown in the hover tooltip.
        text_column: Tooltip caption for the text.
        label_column: Tooltip caption for the original label.

    Returns:
        The Bokeh figure containing the scatter plot.
    """
    # Normalize values to range between 0-255, to assign a color for each value
    max_value = values.max()
    min_value = values.min()
    value_range = max_value - min_value
    if value_range == 0:
        # Every point has the same value: use a single palette entry (index 1).
        color_indices = np.ones(len(values), dtype=int)
    else:
        color_indices = ((values - min_value) / value_range * 255).round().astype(int)

    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()

    # Pair each distinct value string with the colour index computed for that
    # same value. Sorting factors and colours separately as strings (the old
    # approach) is lexicographic, so e.g. "100" < "5" while "0" < "255", which
    # could hand a value the colour belonging to another value. The dict also
    # collapses duplicates, giving one factor per distinct value rather than
    # one per point.
    color_index_by_factor = dict(zip(values_list, color_indices.tolist()))
    factors = list(color_index_by_factor)
    palette = [Pallete[index] for index in color_index_by_factor.values()]

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    # NOTE(review): "{safe}" asks Bokeh not to HTML-escape the tooltip text, so
    # markup contained in the dataset is rendered as-is -- confirm this is
    # intended for user-supplied files.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title="Embedding Lenses")
    p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=palette, factors=factors))
    return p
def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedFile) -> pd.DataFrame:
    """Parse an uploaded delimited text file into a DataFrame.

    The separator is chosen from the file name: a ``tsv`` extension (in any
    letter case) is read as tab-separated, everything else as comma-separated.

    Args:
        uploaded_file: Uploaded file object; its ``name`` attribute is
            inspected and the object itself is passed to ``pd.read_csv``.

    Returns:
        The parsed DataFrame.
    """
    # Lower-case the extension so names such as "DATA.TSV" are still read with
    # tabs instead of silently falling back to commas.
    extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    separator = "\t" if extension == "tsv" else ","
    return pd.read_csv(uploaded_file, sep=separator)
def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
    """Load a dataset via ``load_dataset`` and return a shuffled slice of it.

    Args:
        path: Passed to ``load_dataset`` as ``path``.
        name: Passed as ``name`` only when non-empty.
        split: Passed as ``split`` only when non-empty.
        sample: Upper bound of the ``[:sample]`` slice taken after shuffling.

    Returns:
        A DataFrame built from the sliced, shuffled dataset.
    """
    # Only forward the optional arguments the caller actually filled in.
    load_kwargs = {"path": path}
    if name:
        load_kwargs["name"] = name
    if split:
        load_kwargs["split"] = split

    shuffled = load_dataset(**load_kwargs).shuffle(seed=SEED)
    head = shuffled[:sample]
    return pd.DataFrame(head)
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed the texts of ``df``, reduce them to 2D and draw the scatter plot.

    Args:
        df: Input data. It is not modified; a copy is made when a placeholder
            label column has to be added.
        text_column: Name of the column holding the texts to embed.
        label_column: Name of the column used for colouring. When absent, a
            constant ``0`` label is used instead.
        sample: If truthy, maximum number of rows to keep (random sample
            seeded with ``SEED``).
        dimensionality_reduction_function: Callable applied to the embeddings;
            its result is indexed as ``[:, 0]`` and ``[:, 1]``.
        model: Model passed to ``embed_text``.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    logger.info("Loading dataset in memory")
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame does not gain
        # an extra column as a side effect of plotting.
        df = df.assign(**{label_column: 0})
    # Rows lacking either the text or the label cannot be plotted.
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    logger.info("Embedding sentences")
    embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    logger.info("Running dimensionality reduction")
    embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
    return plot
# ---- Streamlit page; widgets are created in the order they are displayed ----
st.title("Embedding Lenses")
st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")

# Data source, option 1: a user-uploaded csv/tsv file.
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])

# Data source, option 2: a hub dataset described by three side-by-side inputs.
st.write("Alternatively, select a dataset from the hub")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "ag_news")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")

# Column selection and processing options.
text_column = st.text_input("Text column name", "text")
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
sample = st.number_input("Maximum number of documents to use", min_value=1, max_value=100000, value=1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], index=0)
model_name = st.selectbox(
    "Sentence embedding model",
    ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"],
    index=0,
)
model = load_model(model_name)

# Anything other than "UMAP" falls back to t-SNE.
if dimensionality_reduction == "UMAP":
    dimensionality_reduction_function = get_umap_embeddings
else:
    dimensionality_reduction_function = get_tsne_embeddings

if uploaded_file or hub_dataset:
    # An uploaded file takes precedence over the hub dataset.
    df = (
        uploaded_file_to_dataframe(uploaded_file)
        if uploaded_file
        else hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
    )
    plot = generate_plot(
        df,
        text_column,
        label_column,
        sample,
        dimensionality_reduction_function,
        model,
    )
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")
|