context-game / app.py
Allob's picture
Update app.py
91b81a0
raw
history blame
3.04 kB
import streamlit as st
import plotly.express as px
import pandas as pd
import random
import logging
from umap import UMAP
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
@st.cache_resource
def load_model():
    """Build the sentence-embedding model used to score guesses.

    Returns:
        A ``SentenceTransformer`` loaded from the
        ``sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2``
        checkpoint. The function is wrapped in ``st.cache_resource``, so the
        instance is meant to be created once and reused rather than rebuilt
        on every script run.
    """
    checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    encoder = SentenceTransformer(checkpoint)
    return encoder
@st.cache_data
def load_words_dataset():
    """Fetch the vocabulary the game draws from.

    Returns:
        The ``"Word"`` column of the ``train`` split of the
        ``marksverdhei/wordnet-definitions-en-2021`` dataset.
    """
    return load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )["Word"]
@st.cache_data
def choose_secret_word():
    """Pick the word that players have to guess.

    Returns:
        One entry drawn with ``random.choice`` from ``load_words_dataset()``.

    NOTE(review): the ``st.cache_data`` decorator presumably exists so the
    pick survives script reruns instead of changing on each one -- confirm
    that reusing the cached pick this way is the intended behaviour.
    """
    candidates = load_words_dataset()
    secret = random.choice(candidates)
    return secret
@st.cache_resource
def prepare_umap():
    """Fit the 3-D UMAP reducer used to place words in the scatter plot.

    Reads the module-level ``all_words`` (the vocabulary) and ``model`` (the
    sentence encoder); both must be assigned before the first call.

    Returns:
        A fitted ``UMAP`` instance with ``n_components=3``. Callers use its
        ``transform`` method to project embeddings produced by ``model``.
    """
    # Draw the sample of words first and embed only that sample: the reducer
    # is fitted on at most 1000 vectors, so embedding the whole vocabulary
    # would be wasted work. Capping the size keeps random.sample from raising
    # ValueError when the vocabulary has fewer than 1000 entries.
    sample_size = min(1000, len(all_words))
    sample_words = random.sample(list(all_words), k=sample_size)
    sample_enc = model.encode(sample_words)

    umap_3d = UMAP(n_components=3, init='random', random_state=0)
    # Only the fitted reducer is returned; the projection of the sample
    # itself is never used, so fit() is enough.
    umap_3d.fit(sample_enc)
    return umap_3d
# Shared resources: the embedding model, the vocabulary, and the UMAP reducer
# (prepare_umap reads the `model` and `all_words` globals, so it comes last).
model = load_model()
all_words = load_words_dataset()
umap_3d = prepare_umap()

# The word to guess and its embedding; every guess is scored against it.
secret_word = choose_secret_word()
secret_embedding = model.encode(secret_word)
# Goes to the process's stdout only; nothing here writes it to the page.
print("Secret word ", secret_word)
# Guess history for this session: a list of (word, similarity) tuples.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Plot data for this session, created once and seeded with the secret word.
if 'words_umap_df' not in st.session_state:
    # Project the secret word's embedding into the reducer's 3-D space.
    secret_point = umap_3d.transform([secret_embedding])[0]

    plot_columns = ("x", "y", "z", "similarity", "s", "l")
    words_umap_df = pd.DataFrame({column: [] for column in plot_columns})
    # "s" feeds the marker size and "l" the hover label of the scatter plot.
    words_umap_df.loc[len(words_umap_df)] = {
        "x": secret_point[0],
        "y": secret_point[1],
        "z": secret_point[2],
        "similarity": 1,
        "s": 10,
        "l": "Secret word",
    }
    st.session_state['words_umap_df'] = words_umap_df
st.write('Try to guess a secret word by semantic similarity')
# Surrounding whitespace is not part of a guess; stripping it also keeps
# "cat" and " cat" from being recorded as two different guesses.
word = st.text_input("Input a word").strip()
# Words already scored in this session (set: O(1) membership test).
used_words = {w for w, _ in st.session_state['words']}
# The button is rendered so a guess can be submitted with a click (a click
# reruns the script), but the click by itself must not trigger scoring:
# with an empty input box it would record "" as a guess.
st.button("Guess")
if word and word not in used_words:
    word_embedding = model.encode(word)
    # Cosine similarity between the secret word and the guess; the result is
    # a 1x1 tensor, reduced here to a scalar.
    similarity = util.pytorch_cos_sim(
        secret_embedding,
        word_embedding
    ).cpu().numpy()[0][0]
    st.session_state['words'].append((str(word), similarity))

    # Place the guess in the same 3-D space as the secret word.
    pt = umap_3d.transform([word_embedding])[0]
    words_umap_df = st.session_state['words_umap_df']
    words_umap_df.loc[len(words_umap_df)] = {
        "x": pt[0],
        "y": pt[1],
        "z": pt[2],
        "similarity": similarity,
        "s": 3,
        "l": str(word)
    }
    st.session_state['words_umap_df'] = words_umap_df
# Guess history as a table, highest similarity first.
guess_history = pd.DataFrame(
    st.session_state['words'],
    columns=["word", "similarity"],
)
words_df = guess_history.sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df, use_container_width=True)
# 3-D scatter of every plotted word: colour follows the "similarity" column
# on a fixed 0..1 scale, marker size follows the "s" column, and the hover
# box shows only the label from "l".
words_umap_df = st.session_state['words_umap_df']
hidden_hover_fields = dict.fromkeys(("x", "y", "z", "s"), False)
fig_3d = px.scatter_3d(
    words_umap_df,
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data=hidden_hover_fields,
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)