Spaces:
Runtime error
Runtime error
Avoid model mutation warning
Browse files
app.py
CHANGED
@@ -22,9 +22,8 @@ def load_model():
|
|
22 |
return SentenceTransformer(embedder)
|
23 |
|
24 |
|
25 |
-
def embed_text(text: List[str]) -> np.ndarray:
|
26 |
-
|
27 |
-
return embedder_model.encode(text)
|
28 |
|
29 |
|
30 |
def encode_labels(labels: pd.Series) -> pd.Series:
|
@@ -60,7 +59,7 @@ def draw_interactive_scatter_plot(
|
|
60 |
return p
|
61 |
|
62 |
|
63 |
-
def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int]):
|
64 |
logger.info("Loading dataset in memory")
|
65 |
df = pd.read_csv(tsv, sep="\t")
|
66 |
if label_column not in df.columns:
|
@@ -69,7 +68,7 @@ def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str,
|
|
69 |
if sample:
|
70 |
df = df.sample(min(sample, df.shape[0]), random_state=SEED)
|
71 |
logger.info("Embedding sentences")
|
72 |
-
embeddings = embed_text(df[text_column].values.tolist())
|
73 |
logger.info("Encoding labels")
|
74 |
encoded_labels = encode_labels(df[label_column])
|
75 |
logger.info("Running t-SNE")
|
@@ -86,9 +85,10 @@ uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"
|
|
86 |
text_column = st.text_input("Text column name", "text")
|
87 |
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
|
88 |
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
|
|
|
89 |
|
90 |
if uploaded_file:
|
91 |
-
plot = generate_plot(uploaded_file, text_column, label_column, sample)
|
92 |
logger.info("Displaying plot")
|
93 |
st.bokeh_chart(plot)
|
94 |
logger.info("Done")
|
|
|
22 |
return SentenceTransformer(embedder)
|
23 |
|
24 |
|
25 |
+
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
|
26 |
+
return model.encode(text)
|
|
|
27 |
|
28 |
|
29 |
def encode_labels(labels: pd.Series) -> pd.Series:
|
|
|
59 |
return p
|
60 |
|
61 |
|
62 |
+
def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int], model: SentenceTransformer):
|
63 |
logger.info("Loading dataset in memory")
|
64 |
df = pd.read_csv(tsv, sep="\t")
|
65 |
if label_column not in df.columns:
|
|
|
68 |
if sample:
|
69 |
df = df.sample(min(sample, df.shape[0]), random_state=SEED)
|
70 |
logger.info("Embedding sentences")
|
71 |
+
embeddings = embed_text(df[text_column].values.tolist(), model)
|
72 |
logger.info("Encoding labels")
|
73 |
encoded_labels = encode_labels(df[label_column])
|
74 |
logger.info("Running t-SNE")
|
|
|
85 |
text_column = st.text_input("Text column name", "text")
|
86 |
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
|
87 |
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
|
88 |
+
model = load_model()
|
89 |
|
90 |
if uploaded_file:
|
91 |
+
plot = generate_plot(uploaded_file, text_column, label_column, sample, model)
|
92 |
logger.info("Displaying plot")
|
93 |
st.bokeh_chart(plot)
|
94 |
logger.info("Done")
|