File size: 2,796 Bytes
935a660
eabf510
82413ee
eabf510
 
4480f3c
d0d62c4
 
4480f3c
09e6c30
 
 
 
eabf510
 
 
 
 
 
 
 
 
 
 
be07201
eabf510
 
 
 
 
09e6c30
4480f3c
eabf510
 
 
 
 
 
4480f3c
 
 
935a660
eabf510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b71e40
 
 
 
f5f665e
 
ee4ac1e
f5f665e
 
eabf510
 
 
 
 
 
 
eedd478
eabf510
 
 
 
 
 
 
f414b62
 
 
 
 
b26ceaa
d0d62c4
 
eedd478
eabf510
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
import plotly.express as px
import pandas as pd
import random
from umap import UMAP
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset


@st.cache_resource
def load_model():
    """Build the sentence encoder used to embed words.

    Wrapped in ``st.cache_resource``, so the constructed model object is
    handed to Streamlit's resource cache rather than rebuilt by this
    function's callers.

    Returns:
        A ``SentenceTransformer`` for the multilingual MiniLM paraphrase
        checkpoint named below.
    """
    model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    return SentenceTransformer(model_name)


@st.cache_data
def load_words_dataset():
    """Return the ``"Word"`` column of the WordNet definitions dataset.

    Loads the ``train`` split of ``marksverdhei/wordnet-definitions-en-2021``
    through ``datasets.load_dataset`` and keeps only the words; the result
    is cached via ``st.cache_data``.
    """
    return load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )["Word"]


@st.cache_resource
def prepare_umap():
    """Fit a 3-component UMAP reducer on a random sample of word embeddings.

    Encodes the module-level ``all_words`` with the module-level ``model``,
    draws up to 2000 of the resulting embeddings at random and fits the
    reducer on that sample. Both globals must be bound before this is called.

    Returns:
        The fitted ``UMAP`` instance, ready for ``transform`` calls.
    """
    all_enc = model.encode(all_words)

    # random.sample raises ValueError when k exceeds the population size, so
    # cap the sample at the number of available embeddings.
    sample_size = min(2000, len(all_enc))
    # Sample row indices rather than converting the whole embedding matrix to
    # nested Python lists just to pick a subset of its rows.
    sample_rows = random.sample(range(len(all_enc)), k=sample_size)

    umap_3d = UMAP(n_components=3, init='random', random_state=0)
    # Only the fitted reducer is needed; the projection of the sample itself
    # was never used, so fit() replaces fit_transform().
    umap_3d.fit(all_enc[sample_rows])
    return umap_3d


# Cached resources: the vocabulary, the sentence encoder and the fitted UMAP
# reducer. prepare_umap() reads the module-level `all_words` and `model`, so
# both must be bound before it is called.
all_words = load_words_dataset()

model = load_model()

umap_3d = prepare_umap()


# Streamlit re-executes this script on every widget interaction. Drawing the
# secret unconditionally would pick a new word on each rerun, so every guess
# would be scored against a different target than the one already plotted.
# Keep the secret (and its embedding) in session state instead.
if 'secret_word' not in st.session_state:
    st.session_state['secret_word'] = random.choice(all_words)
    st.session_state['secret_embedding'] = model.encode(
        st.session_state['secret_word']
    )

secret_word = st.session_state['secret_word']
secret_embedding = st.session_state['secret_embedding']


# Guess history: a list of (word, similarity) pairs kept across reruns.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Plot data: created once per session and seeded with the secret word's point.
if 'words_umap_df' not in st.session_state:
    plot_columns = ("x", "y", "z", "similarity", "s", "l")
    words_umap_df = pd.DataFrame({column: [] for column in plot_columns})
    st.session_state['words_umap_df'] = words_umap_df

    # Project the secret embedding with the already-fitted reducer; transform
    # takes a batch, hence the one-element list and the [0].
    secret_embedding_3d = umap_3d.transform([secret_embedding])[0]
    secret_x, secret_y, secret_z = secret_embedding_3d

    # "s" feeds the marker size and "l" the hover label of the 3D scatter.
    secret_row = {
        "x": secret_x,
        "y": secret_y,
        "z": secret_z,
        "similarity": 1,
        "s": 10,
        "l": "Secret word"
    }
    words_umap_df.loc[len(words_umap_df)] = secret_row


st.write('Try to guess a secret word by semantic similarity')

# Surrounding whitespace is not part of a guess. Stripping it also turns a
# blank submission into the empty string so it can be rejected below, and
# keeps " cat " from being recorded as a different word than "cat".
word = st.text_input("Input a word").strip()

used_words = {w for w, _ in st.session_state['words']}

# The button only needs to trigger a rerun: the guess itself always comes from
# the text input, so a click with an empty input must not record anything.
st.button("Guess")

if word and word not in used_words:
    word_embedding = model.encode(word)
    # pytorch_cos_sim returns a 1x1 tensor; pull out the scalar.
    similarity = util.pytorch_cos_sim(
        secret_embedding,
        word_embedding
    ).cpu().numpy()[0][0]
    st.session_state['words'].append((word, similarity))

    # Project the guess into the same 3D space as the secret word's point.
    pt = umap_3d.transform([word_embedding])[0]
    plot_df = st.session_state['words_umap_df']
    plot_df.loc[len(plot_df)] = {
        "x": pt[0],
        "y": pt[1],
        "z": pt[2],
        "similarity": similarity,
        "s": 3,
        "l": word
    }

# Table of guesses so far, most similar first.
history_df = pd.DataFrame(st.session_state['words'], columns=["word", "similarity"])
words_df = history_df.sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df, use_container_width=True)


# 3D scatter of the secret word and all guesses; raw coordinates and the size
# column are hidden from the hover box so only the label and similarity show.
hidden_hover_fields = {"x": False, "y": False, "z": False, "s": False}
fig_3d = px.scatter_3d(
    st.session_state['words_umap_df'],
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data=hidden_hover_fields,
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)