File size: 5,579 Bytes
b21feb2 8aa44e7 b21feb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# import os
# import pickle
import streamlit as st
# Placeholder output: the rest of this script (data loading, 2-D reduction,
# vec2text inversion UI) is currently commented out, so the page only
# renders this fixed line of text.
_PLACEHOLDER_TEXT = "This is a test"
st.text(_PLACEHOLDER_TEXT)
# import pandas as pd
# import vec2text
# import torch
# from transformers import AutoModel, AutoTokenizer
# from umap import UMAP
# from tqdm import tqdm
# import plotly.express as px
# import numpy as np
# from sklearn.decomposition import PCA
# # from streamlit_plotly_events import plotly_events
# import plotly.graph_objects as go
# import logging
# import utils
# # Activate tqdm with pandas
# tqdm.pandas()
# # Custom file cache decorator
# def file_cache(file_path):
# def decorator(func):
# def wrapper(*args, **kwargs):
# # Check if the file already exists
# if os.path.exists(file_path):
# # Load from cache
# with open(file_path, "rb") as f:
# print(f"Loading cached data from {file_path}")
# return pickle.load(f)
# else:
# # Compute and save to cache
# result = func(*args, **kwargs)
# with open(file_path, "wb") as f:
# pickle.dump(result, f)
# print(f"Saving new cache to {file_path}")
# return result
# return wrapper
# return decorator
# @st.cache_resource
# def vector_compressor_from_config():
# # Return UMAP with 2 components for dimensionality reduction
# # return UMAP(n_components=2)
# return PCA(n_components=2)
# # Caching the dataframe since loading from an external source can be time-consuming
# @st.cache_data
# def load_data():
# return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
# df = load_data()
# # Caching the model and tokenizer to avoid reloading
# # @st.cache_resource
# # def load_model_and_tokenizer():
# # encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
# # tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
# # return encoder, tokenizer
# # encoder, tokenizer = load_model_and_tokenizer()
# # Caching the vec2text corrector
# # @st.cache_resource
# # def load_corrector():
# # return vec2text.load_pretrained_corrector("gtr-base")
# # corrector = load_corrector()
# # Caching the precomputed embeddings since they are stored locally and large
# @st.cache_data
# def load_embeddings():
# return np.load("syac-title-embeddings.npy")
# embeddings = load_embeddings()
# # Custom cache the UMAP reduction using file_cache decorator
# @st.cache_data
# @file_cache(".cache/reducer_embeddings.pickle")
# def reduce_embeddings(embeddings):
# reducer = vector_compressor_from_config()
# return reducer.fit_transform(embeddings), reducer
# vectors_2d, reducer = reduce_embeddings(embeddings)
# # Add a scatter plot using Plotly
# # fig = px.scatter(
# # x=vectors_2d[:, 0],
# # y=vectors_2d[:, 1],
# # opacity=0.6,
# # hover_data={"Title": df["title"]},
# # labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
# # title="UMAP Scatter Plot of Reddit Titles",
# # color_discrete_sequence=["#ff504c"] # Set default red/coral color for points
# # )
# # # Customize the layout to adapt to browser settings (light/dark mode)
# # fig.update_layout(
# # template=None, # Let Plotly adapt automatically based on user settings
# # plot_bgcolor="rgba(0, 0, 0, 0)",
# # paper_bgcolor="rgba(0, 0, 0, 0)"
# # )
# x, y = 0.0, 0.0
# vec = np.array([x, y]).astype("float32")
# # Add a card container to the right of the content with Streamlit columns
# col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
# with col1:
# # Main content stays here (scatterplot, form, etc.)
# # selected_points = plotly_events(fig, click_event=True, hover_event=False,
# # )
# selected_points = None
# with st.form(key="form1_main"):
# if selected_points:
# clicked_point = selected_points[0]
# x_coord = x = clicked_point['x']
# y_coord = y = clicked_point['y']
# x = st.number_input("X Coordinate", value=x, format="%.10f")
# y = st.number_input("Y Coordinate", value=y, format="%.10f")
# vec = np.array([x, y]).astype("float32")
# submit_button = st.form_submit_button("Submit")
# if selected_points or submit_button:
# inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
# inferred_embedding = inferred_embedding.astype("float32")
# output = vec2text.invert_embeddings(
# embeddings=torch.tensor(inferred_embedding).cuda(),
# corrector=corrector,
# num_steps=20,
# )
# st.text(str(output))
# st.text(str(inferred_embedding))
# else:
# st.text("Click on a point in the scatterplot to see its coordinates.")
# with col2:
# closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
# st.write(f"{vectors_2d.dtype} {vec.dtype}")
# if closest_sentence_index > -1:
# st.write(df["title"].iloc[closest_sentence_index])
# # Card content
# st.markdown("## Card Container")
# st.write("This is an additional card container to the right of the main content.")
# st.write("You can use this space to show additional information, actions, or insights.")
# st.button("Card Button") |