edugp committed on
Commit
9735252
·
1 Parent(s): 262af65

Allow for selecting dimensionality reduction techniques and sentence embedding model. Add UMAP and all-mpnet-base-v2.

Browse files
Files changed (2) hide show
  1. app.py +25 -9
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import logging
2
- from typing import Any, List, Optional
3
 
4
  import numpy as np
5
  import pandas as pd
6
  import streamlit as st
 
7
  from bokeh.models import ColumnDataSource, HoverTool
8
  from bokeh.palettes import Cividis256 as Pallete
9
  from bokeh.plotting import figure
@@ -17,8 +18,8 @@ SEED = 0
17
 
18
 
19
  @st.cache(show_spinner=False, allow_output_mutation=True)
20
- def load_model():
21
- embedder = "distiluse-base-multilingual-cased-v1"
22
  return SentenceTransformer(embedder)
23
 
24
 
@@ -39,6 +40,11 @@ def get_tsne_embeddings(
39
  return tsne.fit_transform(embeddings)
40
 
41
 
 
 
 
 
 
42
  def draw_interactive_scatter_plot(
43
  texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
44
  ) -> Any:
@@ -62,7 +68,14 @@ def draw_interactive_scatter_plot(
62
  return p
63
 
64
 
65
- def generate_plot(uploaded_file: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int], model: SentenceTransformer):
 
 
 
 
 
 
 
66
  logger.info("Loading dataset in memory")
67
  extension = uploaded_file.name.split(".")[-1]
68
  df = pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
@@ -77,11 +90,11 @@ def generate_plot(uploaded_file: st.uploaded_file_manager.UploadedFile, text_col
77
  embeddings = embed_text(df[text_column].values.tolist(), model)
78
  logger.info("Encoding labels")
79
  encoded_labels = encode_labels(df[label_column])
80
- logger.info("Running t-SNE")
81
- tsne_embeddings = get_tsne_embeddings(embeddings)
82
  logger.info("Generating figure")
83
  plot = draw_interactive_scatter_plot(
84
- df[text_column].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
85
  )
86
  return plot
87
 
@@ -92,10 +105,13 @@ uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"
92
  text_column = st.text_input("Text column name", "text")
93
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
94
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
95
- model = load_model()
 
 
 
96
 
97
  if uploaded_file:
98
- plot = generate_plot(uploaded_file, text_column, label_column, sample, model)
99
  logger.info("Displaying plot")
100
  st.bokeh_chart(plot)
101
  logger.info("Done")
 
1
  import logging
2
+ from typing import Any, Callable, List, Optional
3
 
4
  import numpy as np
5
  import pandas as pd
6
  import streamlit as st
7
+ import umap
8
  from bokeh.models import ColumnDataSource, HoverTool
9
  from bokeh.palettes import Cividis256 as Pallete
10
  from bokeh.plotting import figure
 
18
 
19
 
20
  @st.cache(show_spinner=False, allow_output_mutation=True)
21
+ def load_model(model_name):
22
+ embedder = model_name
23
  return SentenceTransformer(embedder)
24
 
25
 
 
40
  return tsne.fit_transform(embeddings)
41
 
42
 
43
+ def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
44
+ umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=SEED)
45
+ return umap_model.fit_transform(embeddings)
46
+
47
+
48
  def draw_interactive_scatter_plot(
49
  texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
50
  ) -> Any:
 
68
  return p
69
 
70
 
71
+ def generate_plot(
72
+ uploaded_file: st.uploaded_file_manager.UploadedFile,
73
+ text_column: str,
74
+ label_column: str,
75
+ sample: Optional[int],
76
+ dimensionality_reduction_function: Callable,
77
+ model: SentenceTransformer,
78
+ ):
79
  logger.info("Loading dataset in memory")
80
  extension = uploaded_file.name.split(".")[-1]
81
  df = pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
 
90
  embeddings = embed_text(df[text_column].values.tolist(), model)
91
  logger.info("Encoding labels")
92
  encoded_labels = encode_labels(df[label_column])
93
+ logger.info("Running dimensionality reduction")
94
+ embeddings_2d = dimensionality_reduction_function(embeddings)
95
  logger.info("Generating figure")
96
  plot = draw_interactive_scatter_plot(
97
+ df[text_column].values, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
98
  )
99
  return plot
100
 
 
105
  text_column = st.text_input("Text column name", "text")
106
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
107
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
108
+ dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", ["UMAP", "t-SNE"], 0)
109
+ model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multilingual-cased-v1", "all-mpnet-base-v2"], 0)
110
+ model = load_model(model_name)
111
+ dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
112
 
113
  if uploaded_file:
114
+ plot = generate_plot(uploaded_file, text_column, label_column, sample, dimensionality_reduction_function, model)
115
  logger.info("Displaying plot")
116
  st.bokeh_chart(plot)
117
  logger.info("Done")
requirements.txt CHANGED
@@ -3,4 +3,6 @@ streamlit==0.84.1
3
  transformers==4.8.2
4
  watchdog==2.1.3
5
  sentence-transformers==2.0.0
6
- bokeh==2.2.2
 
 
 
3
  transformers==4.8.2
4
  watchdog==2.1.3
5
  sentence-transformers==2.0.0
6
+ bokeh==2.2.2
7
+ umap-learn==0.5.1
8
+ numpy==1.20.0