edugp committed on
Commit
a9d1447
·
1 Parent(s): abd3459

Fix logs, type hints and improve error message

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
- from typing import Any, Callable, List, Optional
3
  from functools import partial
 
4
 
5
  import numpy as np
6
  import pandas as pd
@@ -8,7 +8,7 @@ import streamlit as st
8
  import umap
9
  from bokeh.models import ColumnDataSource, HoverTool
10
  from bokeh.palettes import Cividis256 as Pallete
11
- from bokeh.plotting import figure
12
  from bokeh.transform import factor_cmap
13
  from datasets import load_dataset
14
  from sentence_transformers import SentenceTransformer
@@ -20,7 +20,7 @@ SEED = 0
20
 
21
 
22
  @st.cache(show_spinner=False, allow_output_mutation=True)
23
- def load_model(model_name):
24
  embedder = model_name
25
  return SentenceTransformer(embedder)
26
 
@@ -49,7 +49,7 @@ def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
49
 
50
  def draw_interactive_scatter_plot(
51
  texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
52
- ) -> Any:
53
  # Normalize values to range between 0-255, to assign a color for each value
54
  max_value = values.max()
55
  min_value = values.min()
@@ -75,7 +75,7 @@ def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedF
75
  return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
76
 
77
 
78
- def hub_dataset_to_dataframe(path: str, name: str, split: str, text_column: str, label_column: str, sample: int) -> pd.DataFrame:
79
  load_dataset_fn = partial(load_dataset, path=path)
80
  if name:
81
  load_dataset_fn = partial(load_dataset_fn, name=name)
@@ -92,10 +92,10 @@ def generate_plot(
92
  sample: Optional[int],
93
  dimensionality_reduction_function: Callable,
94
  model: SentenceTransformer,
95
- ):
96
  logger.info("Loading dataset in memory")
97
  if text_column not in df.columns:
98
- raise ValueError("The specified column name doesn't exist")
99
  if label_column not in df.columns:
100
  df[label_column] = 0
101
  df = df.dropna(subset=[text_column, label_column])
@@ -138,9 +138,8 @@ if uploaded_file or hub_dataset:
138
  if uploaded_file:
139
  df = uploaded_file_to_dataframe(uploaded_file)
140
  else:
141
- df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, text_column, label_column, sample)
142
  plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
143
- print(type(plot))
144
  logger.info("Displaying plot")
145
  st.bokeh_chart(plot)
146
  logger.info("Done")
 
1
  import logging
 
2
  from functools import partial
3
+ from typing import Callable, List, Optional
4
 
5
  import numpy as np
6
  import pandas as pd
 
8
  import umap
9
  from bokeh.models import ColumnDataSource, HoverTool
10
  from bokeh.palettes import Cividis256 as Pallete
11
+ from bokeh.plotting import Figure, figure
12
  from bokeh.transform import factor_cmap
13
  from datasets import load_dataset
14
  from sentence_transformers import SentenceTransformer
 
20
 
21
 
22
  @st.cache(show_spinner=False, allow_output_mutation=True)
23
+ def load_model(model_name: str) -> SentenceTransformer:
24
  embedder = model_name
25
  return SentenceTransformer(embedder)
26
 
 
49
 
50
  def draw_interactive_scatter_plot(
51
  texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
52
+ ) -> Figure:
53
  # Normalize values to range between 0-255, to assign a color for each value
54
  max_value = values.max()
55
  min_value = values.min()
 
75
  return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
76
 
77
 
78
+ def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
79
  load_dataset_fn = partial(load_dataset, path=path)
80
  if name:
81
  load_dataset_fn = partial(load_dataset_fn, name=name)
 
92
  sample: Optional[int],
93
  dimensionality_reduction_function: Callable,
94
  model: SentenceTransformer,
95
+ ) -> Figure:
96
  logger.info("Loading dataset in memory")
97
  if text_column not in df.columns:
98
+ raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
99
  if label_column not in df.columns:
100
  df[label_column] = 0
101
  df = df.dropna(subset=[text_column, label_column])
 
138
  if uploaded_file:
139
  df = uploaded_file_to_dataframe(uploaded_file)
140
  else:
141
+ df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
142
  plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
 
143
  logger.info("Displaying plot")
144
  st.bokeh_chart(plot)
145
  logger.info("Done")