edugp commited on
Commit
737452a
·
1 Parent(s): 5aad559

Avoid model mutation warning

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -22,9 +22,8 @@ def load_model():
22
  return SentenceTransformer(embedder)
23
 
24
 
25
- def embed_text(text: List[str]) -> np.ndarray:
26
- embedder_model = load_model()
27
- return embedder_model.encode(text)
28
 
29
 
30
  def encode_labels(labels: pd.Series) -> pd.Series:
@@ -60,7 +59,7 @@ def draw_interactive_scatter_plot(
60
  return p
61
 
62
 
63
- def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int]):
64
  logger.info("Loading dataset in memory")
65
  df = pd.read_csv(tsv, sep="\t")
66
  if label_column not in df.columns:
@@ -69,7 +68,7 @@ def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str,
69
  if sample:
70
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
71
  logger.info("Embedding sentences")
72
- embeddings = embed_text(df[text_column].values.tolist())
73
  logger.info("Encoding labels")
74
  encoded_labels = encode_labels(df[label_column])
75
  logger.info("Running t-SNE")
@@ -86,9 +85,10 @@ uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"
86
  text_column = st.text_input("Text column name", "text")
87
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
88
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
 
89
 
90
  if uploaded_file:
91
- plot = generate_plot(uploaded_file, text_column, label_column, sample)
92
  logger.info("Displaying plot")
93
  st.bokeh_chart(plot)
94
  logger.info("Done")
 
22
  return SentenceTransformer(embedder)
23
 
24
 
25
+ def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
26
+ return model.encode(text)
 
27
 
28
 
29
  def encode_labels(labels: pd.Series) -> pd.Series:
 
59
  return p
60
 
61
 
62
+ def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int], model: SentenceTransformer):
63
  logger.info("Loading dataset in memory")
64
  df = pd.read_csv(tsv, sep="\t")
65
  if label_column not in df.columns:
 
68
  if sample:
69
  df = df.sample(min(sample, df.shape[0]), random_state=SEED)
70
  logger.info("Embedding sentences")
71
+ embeddings = embed_text(df[text_column].values.tolist(), model)
72
  logger.info("Encoding labels")
73
  encoded_labels = encode_labels(df[label_column])
74
  logger.info("Running t-SNE")
 
85
  text_column = st.text_input("Text column name", "text")
86
  label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
87
  sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
88
+ model = load_model()
89
 
90
  if uploaded_file:
91
+ plot = generate_plot(uploaded_file, text_column, label_column, sample, model)
92
  logger.info("Displaying plot")
93
  st.bokeh_chart(plot)
94
  logger.info("Done")