File size: 2,726 Bytes
935a660
eabf510
82413ee
eabf510
 
4480f3c
d0d62c4
 
4480f3c
09e6c30
 
 
 
eabf510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09e6c30
4480f3c
eabf510
 
 
 
 
 
4480f3c
 
 
935a660
eabf510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b71e40
 
 
 
f5f665e
 
ee4ac1e
f5f665e
 
eabf510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f414b62
 
 
 
 
 
d0d62c4
 
eabf510
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import plotly.express as px
import pandas as pd
import random
from umap import UMAP
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset


@st.cache_resource
def load_model():
    """Build the sentence-embedding model used to compare words.

    The function is wrapped in ``st.cache_resource``, so the returned
    ``SentenceTransformer`` object is handed out from Streamlit's resource
    cache instead of being constructed again on each call.
    """
    model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    return SentenceTransformer(model_name)


@st.cache_data
def load_words_dataset():
    """Return the ``"Word"`` column of the WordNet definitions dataset.

    Loads the ``train`` split of ``marksverdhei/wordnet-definitions-en-2021``
    through ``load_dataset``. The result is cached by ``st.cache_data``.
    """
    return load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )["Word"]


@st.cache_resource
def prepare_umap():
    """Fit a 3-component UMAP reducer on the embeddings of every known word.

    Reads the module-level ``model`` and ``all_words``, so both must be
    assigned before the first call. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        The fitted ``UMAP`` instance. The rest of the script calls its
        ``transform`` method to place individual word embeddings in the
        same 3-D space.
    """
    all_enc = model.encode(all_words)
    umap_3d = UMAP(n_components=3, init='random', random_state=0)
    # Only the fitted reducer is returned; the projected training points
    # were never used, so fit() replaces the discarded fit_transform() result.
    umap_3d.fit(all_enc)
    return umap_3d


all_words = load_words_dataset()

model = load_model()

umap_3d = prepare_umap()


# Streamlit re-executes this script from the top on every interaction.
# Drawing the secret unconditionally would therefore pick a new word after
# each guess, while the "Secret word" point kept in session state would still
# belong to the first one. Draw it once per session and reuse it afterwards.
if 'secret_word' not in st.session_state:
    st.session_state['secret_word'] = random.choice(all_words)
    st.session_state['secret_embedding'] = model.encode(
        st.session_state['secret_word']
    )

secret_word = st.session_state['secret_word']
secret_embedding = st.session_state['secret_embedding']


# Guess history for this session: a list of (word, similarity) tuples.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Points shown in the 3-D plot. Created once per session and seeded with the
# secret word's projected position.
if 'words_umap_df' not in st.session_state:
    point_columns = ("x", "y", "z", "similarity", "s", "l")
    st.session_state['words_umap_df'] = pd.DataFrame(
        {column: [] for column in point_columns}
    )
    points = st.session_state['words_umap_df']

    # transform() takes a batch; wrap the single embedding and take row 0.
    secret_point = umap_3d.transform([secret_embedding])[0]
    next_row = len(points)
    points.loc[next_row] = {
        "x": secret_point[0],
        "y": secret_point[1],
        "z": secret_point[2],
        "similarity": 1,
        "s": 10,
        "l": "Secret word",
    }

words_umap_df = st.session_state['words_umap_df']

st.write('Try to guess a secret word by semantic similarity')

# Strip surrounding whitespace so " cat" and "cat" count as the same guess.
word = st.text_input("Input a word").strip()

used_words = [w for w, _ in st.session_state['words']]

# The button is rendered for the user, but its return value is not needed:
# clicking it (or pressing Enter in the text box) reruns the script, and a
# guess is only meaningful when the text box holds a word. Previously,
# clicking the button with an empty box recorded "" as a guess.
st.button("Guess")
if word and word not in used_words:
    word_embedding = model.encode(word)
    # pytorch_cos_sim returns a 1x1 tensor here; pull out the single value.
    similarity = util.pytorch_cos_sim(
        secret_embedding,
        word_embedding
    ).cpu().numpy()[0][0]
    st.session_state['words'].append((str(word), similarity))

    # Project the guess into the same 3-D space as the secret word.
    pt = umap_3d.transform([word_embedding])[0]
    words_umap_df.loc[len(words_umap_df)] = {
        "x": pt[0],
        "y": pt[1],
        "z": pt[2],
        "similarity": similarity,
        "s": 3,
        "l": str(word)
    }

# Table of guesses so far, best match first.
words_df = pd.DataFrame(
    st.session_state['words'],
    columns=["word", "similarity"]
).sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df)


# The plot previously referenced ``word_points``, which is never defined and
# raised NameError; the accumulated points live in ``words_umap_df``.
# Rows are appended one at a time to an initially empty frame, so the numeric
# columns may not carry a numeric dtype (not verified); coerce a copy so
# Plotly treats ``similarity`` as a continuous color scale and ``s`` as a size.
plot_df = words_umap_df.astype(
    {"x": float, "y": float, "z": float, "similarity": float, "s": float}
)
fig_3d = px.scatter_3d(
    plot_df,
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data={"x": False, "y": False, "z": False, "s": False},
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)