Spaces:
Sleeping
Sleeping
import streamlit as st | |
import plotly.express as px | |
import pandas as pd | |
import random | |
import logging | |
from umap import UMAP | |
from sentence_transformers import SentenceTransformer, util | |
from datasets import load_dataset | |
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the multilingual MiniLM sentence encoder used to embed words.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps one model instance alive across those reruns
    instead of re-instantiating the model each time.

    Returns:
        SentenceTransformer: the
        ``sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`` model.
    """
    return SentenceTransformer(
        'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    )
@st.cache_data(show_spinner=False)
def load_words_dataset():
    """Return the ``Word`` column of the WordNet definitions dataset.

    The ``train`` split of ``marksverdhei/wordnet-definitions-en-2021`` is
    loaded through ``datasets.load_dataset``. The result is cached with
    ``st.cache_data`` because this function is called more than once per
    script run and the script itself reruns on every interaction.

    Returns:
        list: the entries of the ``Word`` column, materialised as a plain
        list so the cached value is independent of the dataset object.
    """
    dataset = load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )
    return list(dataset["Word"])
def choose_secret_word():
    """Pick the word to be guessed.

    Returns:
        One entry of the list returned by ``load_words_dataset()``, selected
        with ``random.choice``.
    """
    candidates = load_words_dataset()
    secret = random.choice(candidates)
    return secret
@st.cache_resource(show_spinner=False)
def prepare_umap():
    """Fit the 3-D UMAP projection used to place words in the scatter plot.

    Reads the module-level ``all_words`` and ``model``. Up to 1000 words are
    sampled and only that sample is encoded; the reducer is then fitted on
    the sample embeddings.

    The fitted reducer is cached with ``st.cache_resource`` and the sample is
    drawn from a fixed-seed generator. Points projected at different times
    must share one coordinate system, and refitting on a fresh random sample
    at every rerun would break that.

    Returns:
        UMAP: a fitted reducer with ``n_components=3``; call ``transform`` on
        it to project new embeddings.
    """
    vocabulary = list(all_words)
    # A private RNG keeps the sample reproducible without touching the global
    # ``random`` state that is used to draw the secret word.
    sampler = random.Random(0)
    # Clamp the sample size: random.sample raises when k exceeds the population.
    sample_words = sampler.sample(vocabulary, k=min(1000, len(vocabulary)))
    sample_embeddings = model.encode(sample_words)

    umap_3d = UMAP(n_components=3, init='random', random_state=0)
    umap_3d.fit(sample_embeddings)
    return umap_3d
# ---------------------------------------------------------------------------
# Page script.
#
# Streamlit re-executes everything below on every widget interaction, so any
# value that has to survive a rerun (the secret word, the fitted projection,
# the guess history) is kept in ``st.session_state``.
# ---------------------------------------------------------------------------
all_words = load_words_dataset()
model = load_model()

# Fit the projection once per session: points plotted earlier in a game must
# share a coordinate system with points added later.
if 'umap_3d' not in st.session_state:
    st.session_state['umap_3d'] = prepare_umap()
umap_3d = st.session_state['umap_3d']

# Draw the secret once per session. Re-drawing it on every rerun would score
# each guess against a different word than the one shown in the plot.
if 'secret_word' not in st.session_state:
    st.session_state['secret_word'] = choose_secret_word()
    st.session_state['secret_embedding'] = model.encode(
        st.session_state['secret_word']
    )
    print("Secret word ", st.session_state['secret_word'])
secret_word = st.session_state['secret_word']
secret_embedding = st.session_state['secret_embedding']

# Guess history as (word, similarity) tuples, in the order they were entered.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Plot data: one row per point, seeded with the secret word itself.
# Columns: x/y/z = projected coordinates, s = marker size, l = hover label.
if 'words_umap_df' not in st.session_state:
    secret_embedding_3d = umap_3d.transform([secret_embedding])[0]
    st.session_state['words_umap_df'] = pd.DataFrame([{
        "x": secret_embedding_3d[0],
        "y": secret_embedding_3d[1],
        "z": secret_embedding_3d[2],
        "similarity": 1,
        "s": 10,
        "l": "Secret word",
    }])

st.write('Try to guess a secret word by semantic similarity')

word = st.text_input("Input a word").strip()
used_words = {w for w, _ in st.session_state['words']}

# The button's only job is to trigger a rerun; the guess is taken from the
# text box. An empty box is ignored so that a click cannot record "".
st.button("Guess")
if word and word not in used_words:
    word_embedding = model.encode(word)
    similarity = util.pytorch_cos_sim(
        secret_embedding,
        word_embedding,
    ).cpu().numpy()[0][0]
    st.session_state['words'].append((str(word), similarity))

    pt = umap_3d.transform([word_embedding])[0]
    words_umap_df = st.session_state['words_umap_df']
    words_umap_df.loc[len(words_umap_df)] = {
        "x": pt[0],
        "y": pt[1],
        "z": pt[2],
        "similarity": similarity,
        "s": 3,
        "l": str(word),
    }
    st.session_state['words_umap_df'] = words_umap_df

# Guess table, best match first.
words_df = pd.DataFrame(
    st.session_state['words'],
    columns=["word", "similarity"],
).sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df, use_container_width=True)

# 3-D scatter of the secret word and every guess, coloured by similarity.
words_umap_df = st.session_state['words_umap_df']
fig_3d = px.scatter_3d(
    words_umap_df,
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data={"x": False, "y": False, "z": False, "s": False},
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)