Spaces:
Runtime error
Runtime error
Fix logs, add type hints, and improve the error message
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import logging
|
2 |
-
from typing import Any, Callable, List, Optional
|
3 |
from functools import partial
|
|
|
4 |
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
@@ -8,7 +8,7 @@ import streamlit as st
|
|
8 |
import umap
|
9 |
from bokeh.models import ColumnDataSource, HoverTool
|
10 |
from bokeh.palettes import Cividis256 as Pallete
|
11 |
-
from bokeh.plotting import figure
|
12 |
from bokeh.transform import factor_cmap
|
13 |
from datasets import load_dataset
|
14 |
from sentence_transformers import SentenceTransformer
|
@@ -20,7 +20,7 @@ SEED = 0
|
|
20 |
|
21 |
|
22 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
23 |
-
def load_model(model_name):
|
24 |
embedder = model_name
|
25 |
return SentenceTransformer(embedder)
|
26 |
|
@@ -49,7 +49,7 @@ def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
|
|
49 |
|
50 |
def draw_interactive_scatter_plot(
|
51 |
texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
|
52 |
-
) ->
|
53 |
# Normalize values to range between 0-255, to assign a color for each value
|
54 |
max_value = values.max()
|
55 |
min_value = values.min()
|
@@ -75,7 +75,7 @@ def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedF
|
|
75 |
return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
|
76 |
|
77 |
|
78 |
-
def hub_dataset_to_dataframe(path: str, name: str, split: str,
|
79 |
load_dataset_fn = partial(load_dataset, path=path)
|
80 |
if name:
|
81 |
load_dataset_fn = partial(load_dataset_fn, name=name)
|
@@ -92,10 +92,10 @@ def generate_plot(
|
|
92 |
sample: Optional[int],
|
93 |
dimensionality_reduction_function: Callable,
|
94 |
model: SentenceTransformer,
|
95 |
-
):
|
96 |
logger.info("Loading dataset in memory")
|
97 |
if text_column not in df.columns:
|
98 |
-
raise ValueError("The specified column name doesn't exist")
|
99 |
if label_column not in df.columns:
|
100 |
df[label_column] = 0
|
101 |
df = df.dropna(subset=[text_column, label_column])
|
@@ -138,9 +138,8 @@ if uploaded_file or hub_dataset:
|
|
138 |
if uploaded_file:
|
139 |
df = uploaded_file_to_dataframe(uploaded_file)
|
140 |
else:
|
141 |
-
df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split,
|
142 |
plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
|
143 |
-
print(type(plot))
|
144 |
logger.info("Displaying plot")
|
145 |
st.bokeh_chart(plot)
|
146 |
logger.info("Done")
|
|
|
1 |
import logging
|
|
|
2 |
from functools import partial
|
3 |
+
from typing import Callable, List, Optional
|
4 |
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
|
|
8 |
import umap
|
9 |
from bokeh.models import ColumnDataSource, HoverTool
|
10 |
from bokeh.palettes import Cividis256 as Pallete
|
11 |
+
from bokeh.plotting import Figure, figure
|
12 |
from bokeh.transform import factor_cmap
|
13 |
from datasets import load_dataset
|
14 |
from sentence_transformers import SentenceTransformer
|
|
|
20 |
|
21 |
|
22 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
23 |
+
def load_model(model_name: str) -> SentenceTransformer:
|
24 |
embedder = model_name
|
25 |
return SentenceTransformer(embedder)
|
26 |
|
|
|
49 |
|
50 |
def draw_interactive_scatter_plot(
|
51 |
texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
|
52 |
+
) -> Figure:
|
53 |
# Normalize values to range between 0-255, to assign a color for each value
|
54 |
max_value = values.max()
|
55 |
min_value = values.min()
|
|
|
75 |
return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")
|
76 |
|
77 |
|
78 |
+
def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
|
79 |
load_dataset_fn = partial(load_dataset, path=path)
|
80 |
if name:
|
81 |
load_dataset_fn = partial(load_dataset_fn, name=name)
|
|
|
92 |
sample: Optional[int],
|
93 |
dimensionality_reduction_function: Callable,
|
94 |
model: SentenceTransformer,
|
95 |
+
) -> Figure:
|
96 |
logger.info("Loading dataset in memory")
|
97 |
if text_column not in df.columns:
|
98 |
+
raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
|
99 |
if label_column not in df.columns:
|
100 |
df[label_column] = 0
|
101 |
df = df.dropna(subset=[text_column, label_column])
|
|
|
138 |
if uploaded_file:
|
139 |
df = uploaded_file_to_dataframe(uploaded_file)
|
140 |
else:
|
141 |
+
df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
|
142 |
plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
|
|
|
143 |
logger.info("Displaying plot")
|
144 |
st.bokeh_chart(plot)
|
145 |
logger.info("Done")
|