File size: 3,014 Bytes
935a660
eabf510
82413ee
eabf510
0b6790a
eabf510
4480f3c
d0d62c4
 
4480f3c
09e6c30
 
 
 
eabf510
 
 
 
 
 
 
0b6790a
 
 
 
 
 
eabf510
 
 
 
09e6c30
4480f3c
0b6790a
eabf510
 
 
 
78dfe4e
 
eabf510
4480f3c
 
 
935a660
eabf510
f2e9695
eabf510
 
 
 
 
 
 
0b6790a
 
eabf510
 
 
 
 
 
 
 
f2e9695
 
 
eabf510
 
5b71e40
 
 
 
f5f665e
 
ee4ac1e
f5f665e
 
eabf510
 
 
 
 
 
0b6790a
 
f2e9695
 
eabf510
 
 
 
 
 
 
f2e9695
f414b62
 
 
 
 
b26ceaa
d0d62c4
 
f2e9695
 
eabf510
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import plotly.express as px
import pandas as pd
import random
import logging
from umap import UMAP
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset


@st.cache_resource
def load_model():
    """Build the multilingual paraphrase sentence-embedding model.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the same model object
    across script reruns instead of reloading it each time.
    """
    checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    embedder = SentenceTransformer(checkpoint)
    return embedder


@st.cache_data
def load_words_dataset():
    """Return the ``"Word"`` column of the WordNet definitions training split.

    Wrapped in ``st.cache_data`` so the dataset is not re-downloaded/re-read on
    every Streamlit rerun.
    """
    train_split = load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )
    words = train_split["Word"]
    return words


# @st.cache_resource
# def prepare_umap():
#     all_enc = model.encode(all_words)
#     umap_3d = UMAP(n_components=3, init='random', random_state=0)
#     proj_3d = umap_3d.fit_transform(random.sample(all_enc.tolist(), k=2000))
#     return umap_3d


all_words = load_words_dataset()

model = load_model()

#umap_3d = prepare_umap()


# Streamlit re-executes this whole script on every widget interaction. A secret
# drawn at module level would therefore change after every guess, making the
# game unwinnable. Draw it once per browser session and keep both the word and
# its embedding in session_state so later reruns reuse them.
if 'secret_word' not in st.session_state:
    st.session_state['secret_word'] = random.choice(all_words)
    st.session_state['secret_embedding'] = model.encode(st.session_state['secret_word'])
    # Use a %s placeholder so the word is actually interpolated into the
    # message (a bare extra argument triggers a logging formatting error).
    logging.info("Secret word: %s", st.session_state['secret_word'])

secret_word = st.session_state['secret_word']
secret_embedding = st.session_state['secret_embedding']


# Guess history for this session, as (word, similarity) pairs.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Points shown in the 3-D plot for this session; the first row is the secret word.
if 'words_umap_df' not in st.session_state:
    words_umap_df = pd.DataFrame(
        {name: [] for name in ("x", "y", "z", "similarity", "s", "l")}
    )
    #secret_embedding_3d = umap_3d.transform([secret_embedding])[0]
    # Fixed placeholder coordinates while the UMAP projection above is disabled.
    secret_embedding_3d = [0, 1, 2]
    secret_row = dict(zip(("x", "y", "z"), secret_embedding_3d))
    secret_row.update(similarity=1, s=10, l="Secret word")
    words_umap_df.loc[len(words_umap_df)] = secret_row
    st.session_state['words_umap_df'] = words_umap_df




st.write('Try to guess a secret word by semantic similarity')

# Strip surrounding whitespace so " dog" and "dog" count as the same guess.
word = st.text_input("Input a word").strip()

# Words already guessed this session; each entry of 'words' is (word, similarity).
used_words = {w for w, _ in st.session_state['words']}

# Create the button on every run so it is always rendered. Clicking it only
# triggers a rerun; the guess itself is the text-box content, so an empty box
# must not be scored (otherwise clicking "Guess" would score the empty string).
st.button("Guess")

if word and word not in used_words:
    word_embedding = model.encode(word)
    # Cosine similarity of the two embeddings; the result is indexed as a 1x1
    # matrix, so [0][0] extracts the scalar. Stored as a plain Python float.
    similarity = float(util.pytorch_cos_sim(
        secret_embedding,
        word_embedding
    ).cpu().numpy()[0][0])
    st.session_state['words'].append((word, similarity))

    #pt = umap_3d.transform([word_embedding])[0]
    # Placeholder coordinates, apparently until the UMAP projection is re-enabled.
    pt = [0, 1, 2]
    words_umap_df = st.session_state['words_umap_df']
    words_umap_df.loc[len(words_umap_df)] = {
        "x": pt[0],
        "y": pt[1],
        "z": pt[2],
        "similarity": similarity,
        "s": 3,
        "l": word,
    }
    st.session_state['words_umap_df'] = words_umap_df

# Guess history table, best match first.
words_df = pd.DataFrame(st.session_state['words'], columns=["word", "similarity"])
words_df = words_df.sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df, use_container_width=True)


# 3-D scatter of the secret word and all guesses, coloured by similarity in [0, 1];
# coordinate and size columns are hidden from the hover tooltip.
words_umap_df = st.session_state['words_umap_df']
fig_3d = px.scatter_3d(
    words_umap_df,
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data={"x": False, "y": False, "z": False, "s": False},
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)