"""Streamlit word-guessing game: find a secret word by semantic similarity.

Each guess is embedded with a multilingual sentence-transformer model, scored
by cosine similarity against the secret word, listed in a table and plotted in
a 3-D UMAP projection next to the secret word.
"""
import logging
import random

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from umap import UMAP

logger = logging.getLogger(__name__)

MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Number of vocabulary words used to fit the UMAP projection.
UMAP_SAMPLE_SIZE = 1000


@st.cache_resource
def load_model() -> SentenceTransformer:
    """Load the sentence-embedding model once per server process."""
    return SentenceTransformer(MODEL_NAME)


@st.cache_data
def load_words_dataset() -> list:
    """Return the distinct words of the WordNet definitions dataset.

    The dataset has one row per definition, so a word with several senses
    appears several times. Duplicates are dropped (first-seen order kept) so
    every word is equally likely to be drawn or sampled.
    """
    dataset = load_dataset("marksverdhei/wordnet-definitions-en-2021", split="train")
    return list(dict.fromkeys(dataset["Word"]))


@st.cache_data
def choose_secret_word() -> str:
    """Pick the secret word.

    The function takes no arguments and is cached with ``st.cache_data``, so
    the same word is shared by every session until the cache is cleared or
    the server restarts.
    """
    return random.choice(load_words_dataset())


@st.cache_resource
def prepare_umap() -> UMAP:
    """Fit a 3-D UMAP projection on a random sample of the vocabulary.

    Only the sampled words are embedded: encoding the whole vocabulary just
    to fit on a small sample would waste most of the start-up time.
    """
    words = load_words_dataset()
    sample = random.sample(words, k=min(UMAP_SAMPLE_SIZE, len(words)))
    umap_3d = UMAP(n_components=3, init="random", random_state=0)
    umap_3d.fit(load_model().encode(sample))
    return umap_3d


def _make_point(coords, similarity: float, size: int, label: str) -> dict:
    """Build one scatter-plot row from the first three projected coordinates."""
    return {
        "x": float(coords[0]),
        "y": float(coords[1]),
        "z": float(coords[2]),
        "similarity": float(similarity),
        "s": size,
        "l": label,
    }


model = load_model()
umap_3d = prepare_umap()
secret_word = choose_secret_word()
secret_embedding = model.encode(secret_word)
# Developer aid only; emitted when logging is configured at INFO level.
logger.info("Secret word %s", secret_word)

if "words" not in st.session_state:
    # (word, similarity) pairs for every accepted guess, in guess order.
    st.session_state["words"] = []
if "words_umap_df" not in st.session_state:
    # One row per plotted point; the secret word is always the first row.
    secret_point = umap_3d.transform([secret_embedding])[0]
    st.session_state["words_umap_df"] = pd.DataFrame(
        [_make_point(secret_point, 1.0, 10, "Secret word")]
    )

st.write("Try to guess a secret word by semantic similarity")
word = st.text_input("Input a word").strip()
used_words = {guessed for guessed, _ in st.session_state["words"]}

# The button only forces a rerun: a non-empty input is processed either way,
# and an empty input (e.g. clicking with a blank box) must not be recorded.
st.button("Guess")
if word and word not in used_words:
    word_embedding = model.encode(word)
    similarity = float(util.cos_sim(secret_embedding, word_embedding)[0][0])
    st.session_state["words"].append((word, similarity))
    point = umap_3d.transform([word_embedding])[0]
    st.session_state["words_umap_df"] = pd.concat(
        [
            st.session_state["words_umap_df"],
            pd.DataFrame([_make_point(point, similarity, 3, word)]),
        ],
        ignore_index=True,
    )

if any(
    guessed.casefold() == secret_word.casefold()
    for guessed, _ in st.session_state["words"]
):
    st.success(f"You found the secret word: {secret_word}!")

words_df = pd.DataFrame(
    st.session_state["words"], columns=["word", "similarity"]
).sort_values(by="similarity", ascending=False)
st.dataframe(words_df, use_container_width=True)

fig_3d = px.scatter_3d(
    st.session_state["words_umap_df"],
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data={"x": False, "y": False, "z": False, "s": False},
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)