edugp committed on
Commit
bcd62f3
·
1 Parent(s): 118c8b8

Add support for datasets from the Hugging Face hub and default dataset

Browse files
Files changed (1) hide show
  1. app.py +34 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  from typing import Any, Callable, List, Optional
 
3
 
4
  import numpy as np
5
  import pandas as pd
@@ -9,6 +10,7 @@ from bokeh.models import ColumnDataSource, HoverTool
9
  from bokeh.palettes import Cividis256 as Pallete
10
  from bokeh.plotting import figure
11
  from bokeh.transform import factor_cmap
 
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.manifold import TSNE
14
 
@@ -68,8 +70,23 @@ def draw_interactive_scatter_plot(
68
  return p
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def generate_plot(
72
- uploaded_file: st.uploaded_file_manager.UploadedFile,
73
  text_column: str,
74
  label_column: str,
75
  sample: Optional[int],
@@ -77,8 +94,6 @@ def generate_plot(
77
  model: SentenceTransformer,
78
  ):
79
  logger.info("Loading dataset in memory")
80
- extension = uploaded_file.name.split(".")[-1]
81
- df = pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
82
  if text_column not in df.columns:
83
  raise ValueError("The specified column name doesn't exist")
84
  if label_column not in df.columns:
@@ -102,6 +117,15 @@ def generate_plot(
102
  st.title("Embedding Lenses")
103
  st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")
104
  uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
 
 
 
 
 
 
 
 
 
105
  text_column = st.text_input("Text column name", "text")
106
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
107
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
@@ -110,8 +134,13 @@ model_name = st.selectbox("Sentence embedding model", ["distiluse-base-multiling
110
  model = load_model(model_name)
111
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
112
 
113
- if uploaded_file:
114
- plot = generate_plot(uploaded_file, text_column, label_column, sample, dimensionality_reduction_function, model)
 
 
 
 
 
115
  logger.info("Displaying plot")
116
  st.bokeh_chart(plot)
117
  logger.info("Done")
 
1
  import logging
2
  from typing import Any, Callable, List, Optional
3
+ from functools import partial
4
 
5
  import numpy as np
6
  import pandas as pd
 
10
  from bokeh.palettes import Cividis256 as Pallete
11
  from bokeh.plotting import figure
12
  from bokeh.transform import factor_cmap
13
+ from datasets import load_dataset
14
  from sentence_transformers import SentenceTransformer
15
  from sklearn.manifold import TSNE
16
 
 
70
  return p
71
 
72
 
73
def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedFile) -> pd.DataFrame:
    """Parse an uploaded CSV/TSV file into a DataFrame.

    The delimiter is picked from the file extension: a tab for ``tsv``
    (matched case-insensitively, so ``DATA.TSV`` is handled like
    ``data.tsv``) and a comma for everything else.

    Args:
        uploaded_file: File-like object exposing a ``name`` attribute, as
            returned by ``st.file_uploader``.

    Returns:
        The file contents parsed by ``pd.read_csv``.
    """
    # Lower-case the extension so upper/mixed-case names still select the tab separator.
    extension = uploaded_file.name.split(".")[-1].lower()
    separator = "\t" if extension == "tsv" else ","
    return pd.read_csv(uploaded_file, sep=separator)
76
+
77
+
78
def hub_dataset_to_dataframe(path: str, name: str, split: str, text_column: str, label_column: str, sample: int) -> pd.DataFrame:
    """Load a dataset through ``load_dataset`` and return a shuffled sample.

    Args:
        path: Dataset identifier forwarded to ``load_dataset``.
        name: Dataset configuration; not forwarded when empty.
        split: Split to load (e.g. ``"train"``); not forwarded when empty.
        text_column: Not used by this function; kept so existing callers
            keep working.
        label_column: Not used by this function; kept so existing callers
            keep working.
        sample: Maximum number of rows kept after shuffling.

    Returns:
        A DataFrame built from at most ``sample`` shuffled rows.

    Raises:
        ValueError: If no split was requested and ``load_dataset`` returned
            a mapping of splits instead of a single dataset.
    """
    load_kwargs = {"path": path}
    # Only forward the optional arguments the user actually filled in.
    if name:
        load_kwargs["name"] = name
    if split:
        load_kwargs["split"] = split
    dataset = load_dataset(**load_kwargs)
    # Without `split`, load_dataset yields a dict of splits, which cannot be
    # row-sliced; fail with an actionable message instead of an opaque
    # KeyError/TypeError from the slice below.
    if isinstance(dataset, dict):
        available_splits = ", ".join(str(split_name) for split_name in dataset)
        raise ValueError(f"Please specify a dataset split. Available splits: {available_splits}")
    return pd.DataFrame(dataset.shuffle()[:sample])
86
+
87
+
88
  def generate_plot(
89
+ df: pd.DataFrame,
90
  text_column: str,
91
  label_column: str,
92
  sample: Optional[int],
 
94
  model: SentenceTransformer,
95
  ):
96
  logger.info("Loading dataset in memory")
 
 
97
  if text_column not in df.columns:
98
  raise ValueError("The specified column name doesn't exist")
99
  if label_column not in df.columns:
 
117
  st.title("Embedding Lenses")
118
  st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")
119
  uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
120
+ st.write("Alternatively, select a dataset from the hub")
121
+ col1, col2, col3 = st.beta_columns(3)
122
+ with col1:
123
+ hub_dataset = st.text_input("Dataset name", "ag_news")
124
+ with col2:
125
+ hub_dataset_config = st.text_input("Dataset configuration", "")
126
+ with col3:
127
+ hub_dataset_split = st.text_input("Dataset split", "train")
128
+
129
  text_column = st.text_input("Text column name", "text")
130
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
131
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
 
134
  model = load_model(model_name)
135
  dimensionality_reduction_function = get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
136
 
137
# Build and display the plot. An uploaded file takes precedence; otherwise the
# hub dataset named in the text input is used.
if uploaded_file or hub_dataset:
    if uploaded_file:
        df = uploaded_file_to_dataframe(uploaded_file)
    else:
        df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, text_column, label_column, sample)
    plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
    # The display calls stay inside the branch: `plot` is only bound here.
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")