Allob commited on
Commit
ee5ab0e
·
1 Parent(s): 0df4e9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -60
app.py CHANGED
@@ -3,14 +3,13 @@ import plotly.express as px
3
  import pandas as pd
4
  import random
5
  import logging
6
- from umap import UMAP
7
  from sentence_transformers import SentenceTransformer, util
8
  from datasets import load_dataset
9
 
10
 
11
  @st.cache_resource
12
- def load_model():
13
- return SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
14
 
15
 
16
  @st.cache_data
@@ -24,23 +23,20 @@ def choose_secret_word():
24
  return random.choice(all_words)
25
 
26
 
27
- @st.cache_resource
28
- def prepare_umap():
29
- all_enc = model.encode(all_words)
30
- umap_3d = UMAP(n_components=3, init='random', random_state=0)
31
- proj_3d = umap_3d.fit_transform(random.sample(all_enc.tolist(), k=1000))
32
- return umap_3d
33
-
34
-
35
  all_words = load_words_dataset()
36
 
37
- model = load_model()
38
 
39
- umap_3d = prepare_umap()
 
 
 
40
 
 
 
 
41
 
42
- secret_word =choose_secret_word()
43
- secret_embedding = model.encode(secret_word.lower().strip())
44
 
45
  print("Secret word ", secret_word)
46
 
@@ -48,27 +44,6 @@ print("Secret word ", secret_word)
48
  if 'words' not in st.session_state:
49
  st.session_state['words'] = []
50
 
51
- if 'words_umap_df' not in st.session_state:
52
- words_umap_df = pd.DataFrame({
53
- "x": [],
54
- "y": [],
55
- "z": [],
56
- "similarity": [],
57
- "s": [],
58
- "l": [],
59
- })
60
- st.session_state['words_umap_df'] = words_umap_df
61
- secret_embedding_3d = umap_3d.transform([secret_embedding])[0]
62
- words_umap_df.loc[len(words_umap_df)] = {
63
- "x": secret_embedding_3d[0],
64
- "y": secret_embedding_3d[1],
65
- "z": secret_embedding_3d[2],
66
- "similarity": 1,
67
- "s": 10,
68
- "l": "Secret word"
69
- }
70
- st.session_state['words_umap_df'] = words_umap_df
71
-
72
 
73
 
74
 
@@ -80,32 +55,15 @@ used_words = [w for w, s in st.session_state['words']]
80
 
81
  if st.button("Guess") or word:
82
  if word not in used_words:
83
- word_embedding = model.encode(word.lower().strip())
84
- similarity = util.pytorch_cos_sim(
85
- secret_embedding,
86
- word_embedding
87
- ).cpu().numpy()[0][0]
88
- st.session_state['words'].append((str(word), similarity))
89
-
90
- pt = umap_3d.transform([word_embedding])[0]
91
- words_umap_df = st.session_state['words_umap_df']
92
- words_umap_df.loc[len(words_umap_df)] = {
93
- "x": pt[0],
94
- "y": pt[1],
95
- "z": pt[2],
96
- "similarity": similarity,
97
- "s": 3,
98
- "l": str(word)
99
- }
100
- st.session_state['words_umap_df'] = words_umap_df
101
 
102
  words_df = pd.DataFrame(
103
  st.session_state['words'],
104
- columns=["word", "similarity"]
105
- ).sort_values(by=["similarity"], ascending=False)
106
  st.dataframe(words_df, use_container_width=True)
107
 
108
-
109
- words_umap_df = st.session_state['words_umap_df']
110
- fig_3d = px.scatter_3d(words_umap_df, x="x", y="y", z="z", color="similarity", hover_name="l", hover_data={"x": False, "y": False, "z": False, "s": False}, size="s", size_max=10, range_color=(0,1))
111
- st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)
 
3
  import pandas as pd
4
  import random
5
  import logging
 
6
  from sentence_transformers import SentenceTransformer, util
7
  from datasets import load_dataset
8
 
9
 
10
  @st.cache_resource
11
+ def load_model(name):
12
+ return SentenceTransformer(name)
13
 
14
 
15
  @st.cache_data
 
23
  return random.choice(all_words)
24
 
25
 
 
 
 
 
 
 
 
 
26
  all_words = load_words_dataset()
27
 
 
28
 
29
+ model_names = [
30
+ 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
31
+ 'BAAI/bge-small-en-v1.5'
32
+ ]
33
 
34
+ models = {
35
+ name: load_model(name) for name in model_names
36
+ }
37
 
38
+ secret_word =choose_secret_word().lower().strip()
39
+ secret_embedding = [models[name].encode(secret_word) for name in model_names]
40
 
41
  print("Secret word ", secret_word)
42
 
 
44
  if 'words' not in st.session_state:
45
  st.session_state['words'] = []
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
 
 
55
 
56
  if st.button("Guess") or word:
57
  if word not in used_words:
58
+ word_embedding = [models[name].encode(word.lower().strip()) for name in model_names]
59
+ similarities = [util.pytorch_cos_sim(secret_embedding[i], word_embedding[i]).cpu().numpy()[0][0] for i, name in enumerate(model_names)]
60
+ st.session_state['words'].append([str(word)] + similarities))
61
+
62
+
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  words_df = pd.DataFrame(
65
  st.session_state['words'],
66
+ columns=["word"] + ["Similarity for " + name for name in model_names]
67
+ ).sort_values(by=["sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"], ascending=False)
68
  st.dataframe(words_df, use_container_width=True)
69