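"""Embedding Lenses: a Streamlit app that embeds a corpus of texts with Sentence Transformers,
projects the embeddings to 2D with UMAP or t-SNE, and renders an interactive Bokeh scatter plot
colored by a numerical or categorical label column."""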
import logging
from functools import partial
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import streamlit as st
import umap
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
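# Fixed seed for reproducible sampling, shuffling and dimensionality reduction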
SEED = 0


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    # Cache the model across Streamlit reruns so it is only downloaded and loaded once
    return SentenceTransformer(model_name)


def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
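    # Encode the texts into dense sentence embeddings, one row per input text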
    return model.encode(text)


def encode_labels(labels: pd.Series) -> pd.Series:
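    # Numeric labels are used as-is; categorical/text labels are mapped to integer codes for coloring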
    if pd.api.types.is_numeric_dtype(labels):
        return labels
    return labels.astype("category").cat.codes


def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
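    # Project the high-dimensional embeddings down to 2D with t-SNE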
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)


def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
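    # Project the high-dimensional embeddings down to 2D with UMAP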
    umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=SEED)
    return umap_model.fit_transform(embeddings)


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    # Normalize values to the 0-255 range so each value can be assigned a color from the palette
    max_value = values.max()
    min_value = values.min()
    if max_value - min_value == 0:
        # All values are identical: map every point to the same palette entry
        values_color = np.ones(len(values)).astype(int).astype(str)
    else:
        values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str)
    # Deduplicate and sort numerically so each unique value lines up with its palette color
    values_color_set = sorted(set(values_color), key=int)

    values_list = values.astype(str).tolist()
    values_set = sorted(set(values_list), key=float)
    labels_list = labels.astype(str).tolist()

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title="Embedding Lenses")
    p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set))
    return p


def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedFile) -> pd.DataFrame:
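    # Infer the delimiter from the file extension: tab-separated for .tsv, comma-separated otherwise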
    extension = uploaded_file.name.split(".")[-1]
    return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")


def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
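    # Build the load_dataset call incrementally so empty name/split inputs fall back to the library defaults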
    load_dataset_fn = partial(load_dataset, path=path)
    if name:
        load_dataset_fn = partial(load_dataset_fn, name=name)
    if split:
        load_dataset_fn = partial(load_dataset_fn, split=split)
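    # Shuffle deterministically and keep only the first `sample` examples (returned as a dict of columns)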
    dataset = load_dataset_fn().shuffle(seed=SEED)[:sample]
    return pd.DataFrame(dataset)


def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    logger.info("Loading dataset in memory")
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
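        # Label column missing: fall back to a constant dummy label so all points share one color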
        df[label_column] = 0
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    logger.info("Embedding sentences")
    embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    logger.info("Running dimensionality reduction")
    embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
    )
    return plot


st.title("Embedding Lenses")
st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")
uploaded_file = st.file_uploader("Choose a CSV/TSV file...", type=["csv", "tsv"])
st.write("Alternatively, load a dataset from the Hugging Face Hub")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "ag_news")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")

text_column = st.text_input("Text column name", "text")
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
model = load_model(model_name)
dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings

if uploaded_file or hub_dataset:
    if uploaded_file:
        df = uploaded_file_to_dataframe(uploaded_file)
    else:
        df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
    plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")