# import os | |
# import pickle | |
import streamlit as st

# Placeholder output: this is the only live statement in the script while the
# rest of the app below remains commented out.
st.text("This is a test")
# import pandas as pd | |
# import vec2text | |
# import torch | |
# from transformers import AutoModel, AutoTokenizer | |
# from umap import UMAP | |
# from tqdm import tqdm | |
# import plotly.express as px | |
# import numpy as np | |
# from sklearn.decomposition import PCA | |
# # from streamlit_plotly_events import plotly_events | |
# import plotly.graph_objects as go | |
# import logging | |
# import utils | |
# # Activate tqdm with pandas | |
# tqdm.pandas() | |
# # Custom file cache decorator | |
# def file_cache(file_path): | |
# def decorator(func): | |
# def wrapper(*args, **kwargs): | |
# # Check if the file already exists | |
# if os.path.exists(file_path): | |
# # Load from cache | |
# with open(file_path, "rb") as f: | |
# print(f"Loading cached data from {file_path}") | |
# return pickle.load(f) | |
# else: | |
# # Compute and save to cache | |
# result = func(*args, **kwargs) | |
# with open(file_path, "wb") as f: | |
# pickle.dump(result, f) | |
# print(f"Saving new cache to {file_path}") | |
# return result | |
# return wrapper | |
# return decorator | |
# @st.cache_resource | |
# def vector_compressor_from_config(): | |
# # Return a 2-component reducer for dimensionality reduction (PCA is active; the UMAP alternative is kept below)
# # return UMAP(n_components=2) | |
# return PCA(n_components=2) | |
# # Caching the dataframe since loading from an external source can be time-consuming | |
# @st.cache_data | |
# def load_data(): | |
# return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv") | |
# df = load_data() | |
# # Caching the model and tokenizer to avoid reloading | |
# # @st.cache_resource | |
# # def load_model_and_tokenizer(): | |
# # encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda") | |
# # tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base") | |
# # return encoder, tokenizer | |
# # encoder, tokenizer = load_model_and_tokenizer() | |
# # Caching the vec2text corrector | |
# # @st.cache_resource | |
# # def load_corrector(): | |
# # return vec2text.load_pretrained_corrector("gtr-base") | |
# # corrector = load_corrector() | |
# # Caching the precomputed embeddings since they are stored locally and large | |
# @st.cache_data | |
# def load_embeddings(): | |
# return np.load("syac-title-embeddings.npy") | |
# embeddings = load_embeddings() | |
# # Cache the dimensionality reduction on disk using the custom file_cache decorator
# @st.cache_data | |
# @file_cache(".cache/reducer_embeddings.pickle") | |
# def reduce_embeddings(embeddings): | |
# reducer = vector_compressor_from_config() | |
# return reducer.fit_transform(embeddings), reducer | |
# vectors_2d, reducer = reduce_embeddings(embeddings) | |
# # Add a scatter plot using Plotly | |
# # fig = px.scatter( | |
# # x=vectors_2d[:, 0], | |
# # y=vectors_2d[:, 1], | |
# # opacity=0.6, | |
# # hover_data={"Title": df["title"]}, | |
# # labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'}, | |
# # title="UMAP Scatter Plot of Reddit Titles", | |
# # color_discrete_sequence=["#ff504c"] # Set default red color for points
# # ) | |
# # # Customize the layout to adapt to browser settings (light/dark mode) | |
# # fig.update_layout( | |
# # template=None, # Let Plotly adapt automatically based on user settings | |
# # plot_bgcolor="rgba(0, 0, 0, 0)", | |
# # paper_bgcolor="rgba(0, 0, 0, 0)" | |
# # ) | |
# x, y = 0.0, 0.0 | |
# vec = np.array([x, y]).astype("float32") | |
# # Add a card container to the right of the content with Streamlit columns | |
# col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container | |
# with col1: | |
# # Main content stays here (scatterplot, form, etc.) | |
# # selected_points = plotly_events(fig, click_event=True, hover_event=False, | |
# # ) | |
# selected_points = None | |
# with st.form(key="form1_main"): | |
# if selected_points: | |
# clicked_point = selected_points[0] | |
# x_coord = x = clicked_point['x'] | |
# y_coord = y = clicked_point['y'] | |
# x = st.number_input("X Coordinate", value=x, format="%.10f") | |
# y = st.number_input("Y Coordinate", value=y, format="%.10f") | |
# vec = np.array([x, y]).astype("float32") | |
# submit_button = st.form_submit_button("Submit") | |
# if selected_points or submit_button: | |
# inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]])) | |
# inferred_embedding = inferred_embedding.astype("float32") | |
# output = vec2text.invert_embeddings( | |
# embeddings=torch.tensor(inferred_embedding).cuda(), | |
# corrector=corrector, | |
# num_steps=20, | |
# ) | |
# st.text(str(output)) | |
# st.text(str(inferred_embedding)) | |
# else: | |
# st.text("Click on a point in the scatterplot to see its coordinates.") | |
# with col2: | |
# closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3) | |
# st.write(f"{vectors_2d.dtype} {vec.dtype}") | |
# if closest_sentence_index > -1: | |
# st.write(df["title"].iloc[closest_sentence_index]) | |
# # Card content | |
# st.markdown("## Card Container") | |
# st.write("This is an additional card container to the right of the main content.") | |
# st.write("You can use this space to show additional information, actions, or insights.") | |
# st.button("Card Button") |