import marimo
__generated_with = "0.13.0"
app = marimo.App(width="full")
with app.setup:
# Initialization code that runs before all other cells
import marimo as mo
from typing import Dict, Optional, List, Union, Any
from ibm_watsonx_ai import APIClient, Credentials
from pathlib import Path
import pandas as pd
import mimetypes
import requests
import zipfile
import tempfile
import certifi
import base64
import polars
import nltk
import time
import json
import ast
import os
import io
import re
def get_iam_token(api_key):
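# Exchange an IBM Cloud API key for a short-lived IAM bearer token via the IAM identity endpoint.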
return requests.post(
'https://iam.cloud.ibm.com/identity/token',
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key},
verify=certifi.where()
).json()['access_token']
def setup_task_credentials(client):
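# Replace any previously stored task credentials for this client with a freshly stored set.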
# Get existing task credentials
existing_credentials = client.task_credentials.get_details()
# Delete existing credentials if any
if "resources" in existing_credentials and existing_credentials["resources"]:
for cred in existing_credentials["resources"]:
cred_id = client.task_credentials.get_id(cred)
client.task_credentials.delete(cred_id)
# Store new credentials
return client.task_credentials.store()
def get_cred_value(key, creds_var_name="baked_in_creds", default=""): ### Helper for working with preset credentials
"""
Helper function to safely get a value from a credentials dictionary.
Args:
key: The key to look up in the credentials dictionary.
creds_var_name: The variable name of the credentials dictionary.
default: The default value to return if the key is not found.
Returns:
The value from the credentials dictionary if it exists and contains the key,
otherwise returns the default value.
"""
# Check if the credentials variable exists in globals
if creds_var_name in globals():
creds_dict = globals()[creds_var_name]
if isinstance(creds_dict, dict) and key in creds_dict:
return creds_dict[key]
return default
@app.cell
def client_variables(client_instantiation_form):
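# Read the submitted credentials form; all values stay None until the user submits it.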
if client_instantiation_form.value:
client_setup = client_instantiation_form.value
else:
client_setup = None
### Extract Credential Variables:
if client_setup is not None:
wx_url = client_setup["wx_region"]
wx_api_key = client_setup["wx_api_key"].strip()
os.environ["WATSONX_APIKEY"] = wx_api_key
if client_setup["project_id"] is not None:
project_id = client_setup["project_id"].strip()
else:
project_id = None
if client_setup["space_id"] is not None:
space_id = client_setup["space_id"].strip()
else:
space_id = None
else:
os.environ["WATSONX_APIKEY"] = ""
project_id = None
space_id = None
wx_api_key = None
wx_url = None
return client_setup, project_id, space_id, wx_api_key, wx_url
@app.cell
def _(client_setup, wx_api_key):
if client_setup:
token = get_iam_token(wx_api_key)
else:
token = None
return
@app.cell
def _():
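# Optional preset credentials; any non-empty values here pre-fill the client form via get_cred_value().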
baked_in_creds = {
"purpose": "",
"api_key": "",
"project_id": "",
"space_id": "",
}
return (baked_in_creds,)
@app.cell
def client_instantiation(
client_setup,
project_id,
space_id,
wx_api_key,
wx_url,
):
### Instantiate the watsonx.ai client
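# Depending on which IDs were supplied, a project-scoped and/or space-scoped APIClient is created below.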
if client_setup:
wx_credentials = Credentials(
url=wx_url,
api_key=wx_api_key
)
if project_id:
project_client = APIClient(credentials=wx_credentials, project_id=project_id)
else:
project_client = None
if space_id:
deployment_client = APIClient(credentials=wx_credentials, space_id=space_id)
else:
deployment_client = None
if project_client is not None:
task_credentials_details = setup_task_credentials(project_client)
else:
task_credentials_details = setup_task_credentials(deployment_client)
else:
wx_credentials = None
project_client = None
deployment_client = None
task_credentials_details = None
client_status = mo.md("### Client Instantiation Status will turn Green When Ready")
if project_client is not None or deployment_client is not None:
client_callout_kind = "success"
else:
client_callout_kind = "neutral"
return (
client_callout_kind,
client_status,
deployment_client,
project_client,
)
@app.cell
def _():
mo.md(
r"""
# watsonx.ai Embedding Visualizer - Marimo Notebook
#### This marimo notebook helps you develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of embeddings generated from chunked PDF document pages.
#### It can also help identify gaps in model choice, chunking strategy, or collection contents by showing how far your query sits from the chunks you expect it to match.
<br>
/// admonition
Created by ***Milan Mrdenovic*** [[email protected]] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3*** - *20.04.2025*
///
>Licensed under Apache 2.0; users hold full accountability for any use or modification of the code.
><br>This asset is part of a set meant to help IBMers, IBM Partners, and clients develop a better understanding of how to utilize various watsonx features and generative AI.
<br>
"""
)
return
@app.cell
def _():
mo.md("""### Part 1 - Client Setup, File Preparation and Chunking""")
return
@app.cell
def accordion_client_setup(client_selector, client_stack):
ui_accordion_part_1_1 = mo.accordion(
{
"Instantiate Client": mo.vstack([client_stack, client_selector], align="center"),
}
)
ui_accordion_part_1_1
return
@app.cell
def accordion_file_upload(select_stack):
ui_accordion_part_1_2 = mo.accordion(
{
"Select Model & Upload Files": select_stack
}
)
ui_accordion_part_1_2
return
@app.cell
def loaded_texts(
create_temp_files_from_uploads,
file_loader,
pdf_reader,
run_upload_button,
set_text_state,
):
if file_loader.value is not None and run_upload_button.value:
filepaths = create_temp_files_from_uploads(file_loader.value)
loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True)
set_text_state(loaded_texts)
else:
filepaths = None
loaded_texts = None
return
@app.cell
def accordion_chunker_setup(chunker_setup):
ui_accordion_part_1_3 = mo.accordion(
{
"Chunker Setup": chunker_setup
}
)
ui_accordion_part_1_3
return
@app.cell
def chunk_documents_to_nodes(
get_text_state,
sentence_splitter,
sentence_splitter_config,
set_chunk_state,
):
if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
set_chunk_state(chunked_texts)
else:
chunked_texts = None
return (chunked_texts,)
@app.cell
def _():
mo.md(r"""### Part 2 - Query Setup and Visualization""")
return
@app.cell
def accordion_chunk_range(chart_range_selection):
ui_accordion_part_2_1 = mo.accordion(
{
"Chunk Range Selection": chart_range_selection
}
)
ui_accordion_part_2_1
return
@app.cell
def chunk_embedding(
chunks_to_process,
embedding,
sentence_splitter_config,
set_embedding_state,
):
if sentence_splitter_config.value is not None and chunks_to_process is not None:
with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
output_embeddings = embedding.embed_documents(chunks_to_process)
_spinner.update("Almost Done")
time.sleep(1.5)
set_embedding_state(output_embeddings)
_spinner.update("Documents Embedded")
else:
output_embeddings = None
return
@app.cell
def preview_chunks(chunks_dict):
if chunks_dict is not None:
stats = create_stats(chunks_dict,
bordered=True,
object_names=['text','text'],
group_by_row=True,
items_per_row=5,
gap=1,
label="Chunk")
ui_chunk_viewer = mo.accordion(
{
"View Chunks": stats,
}
)
else:
ui_chunk_viewer = None
ui_chunk_viewer
return
@app.cell
def accordion_query_view(chart_visualization, query_stack):
ui_accordion_part_2_2 = mo.accordion(
{
"Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3)
}
)
ui_accordion_part_2_2
return
@app.cell
def chunker_setup(sentence_splitter_config):
chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55])
return (chunker_setup,)
@app.cell
def file_and_model_select(
file_loader,
get_embedding_model_list,
run_upload_button,
):
select_stack = mo.hstack([get_embedding_model_list(), mo.vstack([file_loader, run_upload_button], align="center")], justify="space-around", align="center", widths=[0.3,0.3])
return (select_stack,)
@app.cell
def client_instantiation_form():
# Endpoints
wx_platform_url = "https://api.dataplatform.cloud.ibm.com"
regions = {
"US": "https://us-south.ml.cloud.ibm.com",
"EU": "https://eu-de.ml.cloud.ibm.com",
"GB": "https://eu-gb.ml.cloud.ibm.com",
"JP": "https://jp-tok.ml.cloud.ibm.com",
"AU": "https://au-syd.ml.cloud.ibm.com",
"CA": "https://ca-tor.ml.cloud.ibm.com"
}
# Create a form with multiple elements
client_instantiation_form = (
mo.md('''
### **watsonx.ai credentials:**
{wx_region}
{wx_api_key}
{project_id}
{space_id}
> You can provide a project_id, a space_id, or both; **at least one is required**.
> If you provide both, you can switch the active one in the dropdown.
''')
.batch(
wx_region = mo.ui.dropdown(regions, label="Select your watsonx.ai region:", value="US", searchable=True),
wx_api_key = mo.ui.text(placeholder="Add your IBM Cloud api-key...", label="IBM Cloud Api-key:",
kind="password", value=get_cred_value('api_key', creds_var_name='baked_in_creds')),
project_id = mo.ui.text(placeholder="Add your watsonx.ai project_id...", label="Project_ID:",
kind="text", value=get_cred_value('project_id', creds_var_name='baked_in_creds')),
space_id = mo.ui.text(placeholder="Add your watsonx.ai space_id...", label="Space_ID:",
kind="text", value=get_cred_value('space_id', creds_var_name='baked_in_creds'))
,)
.form(show_clear_button=True, bordered=False)
)
return (client_instantiation_form,)
@app.cell
def instantiation_status(
client_callout_kind,
client_instantiation_form,
client_status,
):
client_callout = mo.callout(client_status, kind=client_callout_kind)
client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10)
return (client_stack,)
@app.cell
def client_selector(deployment_client, project_client):
if project_client is not None and deployment_client is not None:
client_options = {"Project Client":project_client,"Deployment Client":deployment_client}
elif deployment_client is not None:
client_options = {"Deployment Client":deployment_client}
elif project_client is not None:
client_options = {"Project Client":project_client}
else:
client_options = {"No Client": "Instantiate a Client"}
default_client = next(iter(client_options))
client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Select your active client:**")
return (client_selector,)
@app.cell
def active_client(client_selector):
client_key = client_selector.value
if client_key == "Instantiate a Client":
client = None
else:
client = client_key
return (client,)
@app.cell
def emb_model_selection(client, set_embedding_model_list):
if client is not None:
model_specs = client.foundation_models.get_embeddings_model_specs()
# model_specs = client.foundation_models.get_model_specs()
resources = model_specs["resources"]
# Define embedding models reference data
embedding_models = {
"ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
"ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
"ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
"ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
"ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
"ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
"sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
"sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
"intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
}
# Get model IDs from resources
model_id_list = []
for resource in resources:
model_id_list.append(resource["model_id"])
# Create enhanced model data for the table
embedding_model_data = []
for model_id in model_id_list:
model_entry = {"model_id": model_id}
# Add properties if model exists in our reference, otherwise use 0
if model_id in embedding_models:
model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"]
model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"]
else:
model_entry["max_tokens"] = 0
model_entry["embedding_dimensions"] = 0
embedding_model_data.append(model_entry)
embedding_model_selection = mo.ui.table(
embedding_model_data,
selection="single", # Only allow selecting one row
label="Select an embedding model to use.",
page_size=30,
initial_selection=[1]
)
set_embedding_model_list(embedding_model_selection)
else:
default_model_data = [{
"model_id": "ibm/granite-embedding-107m-multilingual",
"max_tokens": 512,
"embedding_dimensions": 384
}]
set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
return
@app.function
def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
embedding_model_selection = mo.ui.table(
model_data,
selection=selection_type, # Only allow selecting one row
label=label,
page_size=30,
initial_selection=[initial_selection]
)
return embedding_model_selection
@app.cell
def embedding_model():
get_embedding_model_list, set_embedding_model_list = mo.state(None)
return get_embedding_model_list, set_embedding_model_list
@app.cell
def emb_model_parameters(emb_model_max_tk):
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
if emb_model_max_tk is not None:
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk,
EmbedParams.RETURN_OPTIONS: {
'input_text': True
}
}
else:
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 128,
EmbedParams.RETURN_OPTIONS: {
'input_text': True
}
}
return (embed_params,)
@app.cell
def emb_model_state(get_embedding_model_list):
embedding_model = get_embedding_model_list()
return (embedding_model,)
@app.cell
def emb_model_setup(embedding_model):
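# Unpack the selected table row into the model id, its max token limit, and its embedding dimensions.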
if embedding_model is not None:
emb_model = embedding_model.value[0]['model_id']
emb_model_max_tk = embedding_model.value[0]['max_tokens']
emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions']
else:
emb_model = None
emb_model_max_tk = None
emb_model_emb_dim = None
return emb_model, emb_model_emb_dim, emb_model_max_tk
@app.cell
def emb_model_instantiation(client, emb_model, embed_params):
from ibm_watsonx_ai.foundation_models import Embeddings
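# batch_size controls how many chunks are sent per embedding request; concurrency_limit caps parallel requests.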
if client is not None:
embedding = Embeddings(
model_id=emb_model,
api_client=client,
params=embed_params,
batch_size=1000,
concurrency_limit=10
)
else:
embedding = None
return (embedding,)
@app.cell
def _():
get_embedding_state, set_embedding_state = mo.state(None)
return get_embedding_state, set_embedding_state
@app.cell
def _():
get_query_state, set_query_state = mo.state(None)
return get_query_state, set_query_state
@app.cell
def file_loader_input():
file_loader = mo.ui.file(
kind="area",
filetypes=[".pdf"],
label=" Load .pdf files ",
multiple=True
)
return (file_loader,)
@app.cell
def file_loader_run(file_loader):
if file_loader.value:
run_upload_button = mo.ui.run_button(label="Load Files")
else:
run_upload_button = mo.ui.run_button(disabled=True, label="Load Files")
return (run_upload_button,)
@app.cell
def helper_function_tempfiles():
def create_temp_files_from_uploads(upload_results) -> List[str]:
"""
Creates temporary files from a tuple of FileUploadResults objects and returns their paths.
Args:
upload_results: Object containing a value attribute that is a tuple of FileUploadResults
Returns:
List of temporary file paths
"""
temp_file_paths = []
# Get the number of items in the tuple
num_items = len(upload_results)
# Process each item by index
for i in range(num_items):
result = upload_results[i] # Get item by index
# Create a temporary file with the original filename
temp_dir = tempfile.gettempdir()
file_name = result.name
temp_path = os.path.join(temp_dir, file_name)
# Write the contents to the temp file
with open(temp_path, 'wb') as temp_file:
temp_file.write(result.contents)
# Add the path to our list
temp_file_paths.append(temp_path)
return temp_file_paths
def cleanup_temp_files(temp_file_paths: List[str]) -> None:
"""Delete temporary files after use."""
for path in temp_file_paths:
if os.path.exists(path):
os.unlink(path)
return (create_temp_files_from_uploads,)
@app.function
def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
"""
Loads PDF data for each file path and organizes results by original filename.
Args:
pdf_reader: The PyMuPDFReader instance
filepaths: List of temporary file paths
file_loader_value: The original upload results value containing file information
show_progress: Whether to show a progress bar during loading (default: False)
Returns:
Dictionary mapping original filenames to their loaded text content
"""
results = {}
# Process files with or without progress bar
if show_progress:
import marimo as mo
# Use progress bar with the length of filepaths as total
with mo.status.progress_bar(
total=len(filepaths),
title="Loading PDFs",
subtitle="Processing documents...",
completion_title="PDF Loading Complete",
completion_subtitle=f"{len(filepaths)} documents processed",
remove_on_exit=True
) as bar:
# Process each file path
for i, file_path in enumerate(filepaths):
original_file_name = file_loader_value[i].name
bar.update(subtitle=f"Processing {original_file_name}...")
loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
# Store the result with the original filename as the key
results[original_file_name] = loaded_text
# Update progress bar
bar.update(increment=1)
else:
# Original logic without progress bar
for i, file_path in enumerate(filepaths):
original_file_name = file_loader_value[i].name
loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
results[original_file_name] = loaded_text
return results
@app.cell
def file_readers():
from llama_index.readers.file import PyMuPDFReader
from llama_index.readers.file import FlatReader
from llama_index.core.node_parser import SentenceSplitter
### File Readers
pdf_reader = PyMuPDFReader()
# flat_file_reader = FlatReader()
return SentenceSplitter, pdf_reader
@app.cell
def sentence_splitter_setup():
### Chunker Setup
sentence_splitter_config = (
mo.md('''
###**Chunking Setup:**
> Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap; changing the other settings may produce unexpected results.
The separator is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.
{chunk_size}
{chunk_overlap}
{separator} {paragraph_separator}
{secondary_chunking_regex} {include_metadata}
''')
.batch(
chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=350, show_value=True, full_width=True),
chunk_overlap = mo.ui.slider(start=1, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=50, show_value=True, full_width=True),
separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
paragraph_separator = mo.ui.text(placeholder="Define a paragraph separator",
label="**Paragraph Separator:**", kind="text",
value="\n\n\n"),
secondary_chunking_regex = mo.ui.text(placeholder="Define a secondary chunking regex",
label="**Chunking Regex:**", kind="text",
value="[^,.;?!]+[,.;?!]?"),
include_metadata= mo.ui.checkbox(value=True, label="**Include Metadata**")
)
.form(show_clear_button=True, bordered=False)
)
return (sentence_splitter_config,)
@app.cell
def sentence_splitter_instantiation(
SentenceSplitter,
sentence_splitter_config,
):
### Chunker/Sentence Splitter
def simple_whitespace_tokenizer(text):
return text.split()
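# Note: the whitespace tokenizer above only approximates model tokenization, so chunk sizes are measured in words rather than true tokens.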
if sentence_splitter_config.value is not None:
sentence_splitter_config_values = sentence_splitter_config.value
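# Cap the overlap at 30% of the chunk size so it can never reach or exceed the chunk size itself.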
validated_chunk_overlap = min(sentence_splitter_config_values.get("chunk_overlap"),
int(sentence_splitter_config_values.get("chunk_size") * 0.3))
sentence_splitter = SentenceSplitter(
chunk_size=sentence_splitter_config_values.get("chunk_size"),
chunk_overlap=validated_chunk_overlap,
separator=sentence_splitter_config_values.get("separator"),
paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"),
secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"),
include_metadata=sentence_splitter_config_values.get("include_metadata"),
tokenizer=simple_whitespace_tokenizer
)
else:
sentence_splitter = SentenceSplitter(
chunk_size=2048,
chunk_overlap=204,
separator=" ",
paragraph_separator="\n\n\n",
secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
include_metadata=True,
tokenizer=simple_whitespace_tokenizer
)
return (sentence_splitter,)
@app.cell
def text_state():
get_text_state, set_text_state = mo.state(None)
return get_text_state, set_text_state
@app.cell
def chunk_state():
get_chunk_state, set_chunk_state = mo.state(None)
return get_chunk_state, set_chunk_state
@app.function
def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
"""
Process each document in the loaded_texts dictionary using the sentence_splitter,
with an optional marimo progress bar tracking progress at document level.
Args:
loaded_texts (dict): Dictionary containing lists of Document objects
sentence_splitter: The sentence splitter object with get_nodes_from_documents method
show_progress (bool): Whether to show a progress bar during processing
Returns:
dict: Dictionary with the same structure but containing chunked texts
"""
chunked_texts_dict = {}
# Get the total number of documents across all keys
total_docs = sum(len(docs) for docs in loaded_texts.values())
processed_docs = 0
# Process with or without progress bar
if show_progress:
import marimo as mo
# Use progress bar with the total number of documents as total
with mo.status.progress_bar(
total=total_docs,
title="Processing Documents",
subtitle="Chunking documents...",
completion_title="Processing Complete",
completion_subtitle=f"{total_docs} documents processed",
remove_on_exit=True
) as bar:
# Process each key-value pair in the loaded_texts dictionary
for key, documents in loaded_texts.items():
# Update progress bar subtitle to show current key
doc_count = len(documents)
bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")
# Apply the sentence splitter to each list of documents
chunked_texts = sentence_splitter.get_nodes_from_documents(
documents,
show_progress=False # Disable internal progress to avoid nested bars
)
# Store the result with the same key
chunked_texts_dict[key] = chunked_texts
time.sleep(0.15)
# Update progress bar with the number of documents in this batch
bar.update(increment=doc_count)
processed_docs += doc_count
else:
# Process without progress bar
for key, documents in loaded_texts.items():
chunked_texts = sentence_splitter.get_nodes_from_documents(
documents,
show_progress=True # Use the internal progress bar if no marimo bar
)
chunked_texts_dict[key] = chunked_texts
return chunked_texts_dict
@app.cell
def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
if chunked_texts is not None and sentence_splitter:
chunked_documents = get_chunk_state()
else:
chunked_documents = None
return (chunked_documents,)
@app.cell
def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
if chunked_documents is not None:
dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)
else:
dict_from_nodes = None
nodes_from_dict = None
return (dict_from_nodes,)
@app.cell
def chunks_to_process(
dict_from_nodes,
document_range_stack,
get_data_in_range_triplequote,
):
if dict_from_nodes is not None and document_range_stack is not None:
chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)
if document_range_stack.value is not None:
chunk_start_idx = document_range_stack.value[0]
chunk_end_idx = document_range_stack.value[1]
else:
chunk_start_idx = 0
chunk_end_idx = len(chunk_dict_df)
chunk_range_index = [chunk_start_idx, chunk_end_idx]
chunks_dict = get_data_in_range_triplequote(chunk_dict_df,
index_range=chunk_range_index,
columns_to_include=["text"])
chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
else:
chunk_objects = None
chunks_dict = None
chunks_to_process = None
return chunks_dict, chunks_to_process
@app.cell
def helper_function_doc_formatting():
def llamaindex_convert_docs_multi(items):
"""
Automatically convert between document objects and dictionaries.
This function handles:
- Converting dictionaries to document objects
- Converting document objects to dictionaries
- Processing lists or individual items
- Supporting dictionary structures where values are lists of documents
Args:
items: A document object, dictionary, or list of either.
Can also be a dictionary mapping filenames to lists of documents.
Returns:
Converted item(s) maintaining the original structure
"""
# Handle empty or None input
if not items:
return []
# Handle dictionary mapping filenames to document lists (from load_pdf_data)
if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
result = {}
for filename, doc_list in items.items():
result[filename] = llamaindex_convert_docs(doc_list)
return result
# Handle single items (not in a list)
if not isinstance(items, list):
# Single dictionary to document
if isinstance(items, dict):
# Determine document class
doc_class = None
if 'doc_type' in items:
import importlib
module_path, class_name = items['doc_type'].rsplit('.', 1)
module = importlib.import_module(module_path)
doc_class = getattr(module, class_name)
if not doc_class:
from llama_index.core.schema import Document
doc_class = Document
return doc_class.from_dict(items)
# Single document to dictionary
elif hasattr(items, 'to_dict'):
return items.to_dict()
# Return as is if can't convert
return items
# Handle list input
result = []
# Handle empty list
if len(items) == 0:
return result
# Determine the type of conversion based on the first non-None item
first_item = next((item for item in items if item is not None), None)
# If we found no non-None items, return empty list
if first_item is None:
return result
# Convert dictionaries to documents
if isinstance(first_item, dict):
# Get the right document class from the items themselves
doc_class = None
# Try to get doc class from metadata if available
if 'doc_type' in first_item:
import importlib
module_path, class_name = first_item['doc_type'].rsplit('.', 1)
module = importlib.import_module(module_path)
doc_class = getattr(module, class_name)
if not doc_class:
# Fallback to default Document class from llama_index
from llama_index.core.schema import Document
doc_class = Document
# Convert each dictionary to document
for item in items:
if isinstance(item, dict):
result.append(doc_class.from_dict(item))
elif item is None:
result.append(None)
elif isinstance(item, list):
result.append(llamaindex_convert_docs(item))
else:
result.append(item)
# Convert documents to dictionaries
else:
for item in items:
if hasattr(item, 'to_dict'):
result.append(item.to_dict())
elif item is None:
result.append(None)
elif isinstance(item, list):
result.append(llamaindex_convert_docs(item))
else:
result.append(item)
return result
def llamaindex_convert_docs(items):
"""
Automatically convert between document objects and dictionaries.
Args:
items: A list of document objects or dictionaries
Returns:
List of converted items (dictionaries or document objects)
"""
result = []
# Handle empty or None input
if not items:
return result
# Determine the type of conversion based on the first item
if isinstance(items[0], dict):
# Get the right document class from the items themselves
# Look for a 'doc_type' or '__class__' field in the dictionary
doc_class = None
# Try to get doc class from metadata if available
if 'doc_type' in items[0]:
import importlib
module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
module = importlib.import_module(module_path)
doc_class = getattr(module, class_name)
if not doc_class:
# Fallback to default Document class from llama_index
from llama_index.core.schema import Document
doc_class = Document
# Convert dictionaries to documents
for item in items:
if isinstance(item, dict):
result.append(doc_class.from_dict(item))
else:
# Convert documents to dictionaries
for item in items:
if hasattr(item, 'to_dict'):
result.append(item.to_dict())
return result
return (llamaindex_convert_docs_multi,)
@app.cell
def helper_function_create_df():
def create_document_dataframes(dict_from_docs):
"""
Creates a pandas DataFrame for each file in the dictionary.
Args:
dict_from_docs: Dictionary mapping filenames to lists of documents
Returns:
List of pandas DataFrames, each representing all documents from a single file
"""
dataframes = []
for filename, docs in dict_from_docs.items():
# Create a list to hold all document records for this file
file_records = []
for i, doc in enumerate(docs):
# Convert the document to a format compatible with DataFrame
if hasattr(doc, 'to_dict'):
doc_data = doc.to_dict()
elif isinstance(doc, dict):
doc_data = doc
else:
doc_data = {'content': str(doc)}
# Add document index information
doc_data['doc_index'] = i
# Add to the list of records for this file
file_records.append(doc_data)
# Create a single DataFrame for all documents in this file
if file_records:
df = pd.DataFrame(file_records)
df['filename'] = filename # Add filename as a column
dataframes.append(df)
return dataframes
def create_dataframe_previews(dataframe_list, page_size=5):
"""
Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.
Args:
dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
page_size: Number of rows to show per page for each component
Returns:
List of mo.ui.dataframe components
"""
# Create a list of mo.ui.dataframe components
preview_components = []
for df in dataframe_list:
# Create a mo.ui.dataframe component for this DataFrame
preview = mo.ui.dataframe(df, page_size=page_size)
preview_components.append(preview)
return preview_components
return
@app.cell
def helper_function_chart_preparation():
import altair as alt
import numpy as np
import plotly.express as px
from sklearn.manifold import TSNE
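# t-SNE is used purely for visualization: it projects the high-dimensional embedding vectors down to 2D/3D coordinates.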
def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
"""
Prepare embedding data for visualization
Args:
embeddings: List of embeddings arrays
texts: List of text strings
model_id: Embedding model ID (optional)
embedding_dimensions: Embedding dimensions (optional)
Returns:
DataFrame with processed data and metadata
"""
# Flatten embeddings (in case they're nested)
flattened_embeddings = []
for emb in embeddings:
if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
flattened_embeddings.append(emb[0]) # Take first element if nested
else:
flattened_embeddings.append(emb)
# Convert to numpy array
embedding_array = np.array(flattened_embeddings)
# Apply dimensionality reduction (t-SNE)
tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1))
reduced_embeddings = tsne.fit_transform(embedding_array)
# Create truncated texts for display
truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]
# Create dataframe for visualization
df = pd.DataFrame({
"x": reduced_embeddings[:, 0],
"y": reduced_embeddings[:, 1],
"text": truncated_texts,
"full_text": texts,
"index": range(len(texts))
})
# Add metadata
metadata = {
"model_id": model_id,
"embedding_dimensions": embedding_dimensions
}
return df, metadata
def create_embedding_chart(df, metadata=None):
"""
Create an Altair chart for embedding visualization
Args:
df: DataFrame with x, y coordinates and text
metadata: Dictionary with model_id and embedding_dimensions
Returns:
Altair chart
"""
model_id = metadata.get("model_id") if metadata else None
embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
selection = alt.selection_multi(fields=['index'])
base = alt.Chart(df).encode(
x=alt.X("x:Q", title="Dimension 1"),
y=alt.Y("y:Q", title="Dimension 2"),
tooltip=["text", "index"]
)
points = base.mark_circle(size=100).encode(
color=alt.Color("index:N", legend=None),
opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
).add_selection(selection) # Add this line to apply the selection
text = base.mark_text(align="left", dx=7).encode(
text="index:N"
)
return (points + text).properties(
width=700,
height=500,
title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
).interactive()
def show_selected_text(indices, texts):
"""
Create markdown display for selected texts
Args:
indices: List of selected indices
texts: List of all texts
Returns:
Markdown string
"""
if not indices:
return "No text selected"
selected_texts = [texts[i] for i in indices if i < len(texts)]
return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])
def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
"""
Prepare embedding data for 3D visualization
Args:
embeddings: List of embeddings arrays
texts: List of text strings
model_id: Embedding model ID (optional)
embedding_dimensions: Embedding dimensions (optional)
Returns:
DataFrame with processed data and metadata
"""
# Flatten embeddings (in case they're nested)
flattened_embeddings = []
for emb in embeddings:
if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
flattened_embeddings.append(emb[0])
else:
flattened_embeddings.append(emb)
# Convert to numpy array
embedding_array = np.array(flattened_embeddings)
# Handle the case of a single embedding differently
if len(embedding_array) == 1:
# For a single point, we don't need t-SNE, just use a fixed position
reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
else:
# Apply dimensionality reduction to 3D
# Fix: Ensure perplexity is at least 1.0
perplexity_value = max(1.0, min(30, len(embedding_array)-1))
tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
reduced_embeddings = tsne.fit_transform(embedding_array)
# Format texts for display
formatted_texts = []
for text in texts:
# Truncate if needed
if len(text) > 500:
text = text[:500] + "..."
# Insert line breaks for wrapping
wrapped_text = ""
for i in range(0, len(text), 50):
wrapped_text += text[i:i+50] + "<br>"
formatted_texts.append("<b>"+wrapped_text+"</b>")
# Create dataframe for visualization
df = pd.DataFrame({
"x": reduced_embeddings[:, 0],
"y": reduced_embeddings[:, 1],
"z": reduced_embeddings[:, 2],
"text": formatted_texts,
"full_text": texts,
"index": range(len(texts)),
"embedding": flattened_embeddings # Store the original embeddings for later use
})
# Add metadata
metadata = {
"model_id": model_id,
"embedding_dimensions": embedding_dimensions
}
return df, metadata
def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int=3):
"""
Create a 3D Plotly chart for embedding visualization with proximity-based coloring
"""
model_id = metadata.get("model_id") if metadata else None
embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
# Calculate the proximity between points
from scipy.spatial.distance import pdist, squareform
# Get the coordinates as a numpy array
coords = df[['x', 'y', 'z']].values
# Calculate pairwise distances
dist_matrix = squareform(pdist(coords))
# For each point, find its average distance to all other points
avg_distances = np.mean(dist_matrix, axis=1)
# Add this to the dataframe - smaller values = closer to other points
df['proximity'] = avg_distances
# Create 3D scatter plot with proximity-based coloring
fig = px.scatter_3d(
df,
x='x',
y='y',
z='z',
color='proximity', # Color based on proximity
color_continuous_scale='Viridis_r', # Reversed so closer points are warmer colors
hover_data=['text', 'index', 'proximity'],
labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
text='index',
# size_max=marker_size_var
)
# Update marker size and layout
fig.update_traces(
marker=dict(
size=marker_size_var, # Very small marker size
opacity=0.7, # Slightly transparent
symbol="diamond", # Use circle markers (other options: "square", "diamond", "cross", "x")
line=dict(
width=0.5, # Very thin border
color="white" # White outline makes small dots more visible
)
),
textfont=dict(
color="rgba(255, 255, 255, 0.3)",
size=8
),
# hovertemplate="<b>index=%{text}</b><br>%{customdata[0]}<br><br>Avg Distance=%{customdata[2]:.4f}<extra></extra>", ### Hover Changes
hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
hoverinfo="text+name",
hoverlabel=dict(
bgcolor="white", # White background for hover labels
font_size=12 # Font size for hover text
),
selector=dict(type='scatter3d')
)
# Keep your existing layout settings
fig.update_layout(
scene=dict(
xaxis=dict(
title='Dimension 1',
nticks=40,
backgroundcolor="rgba(10, 10, 20, 0.1)",
gridcolor="white",
showbackground=True,
gridwidth=0.35,
zerolinecolor="white",
),
yaxis=dict(
title='Dimension 2',
nticks=40,
backgroundcolor="rgba(10, 10, 20, 0.1)",
gridcolor="white",
showbackground=True,
gridwidth=0.35,
zerolinecolor="white",
),
zaxis=dict(
title='Dimension 3',
nticks=40,
backgroundcolor="rgba(10, 10, 20, 0.1)",
gridcolor="white",
showbackground=True,
gridwidth=0.35,
zerolinecolor="white",
),
# Control camera view angle
camera=dict(
up=dict(x=0, y=0, z=1),
center=dict(x=0, y=0, z=0),
eye=dict(x=1.25, y=1.25, z=1.25),
),
aspectratio=dict(x=1, y=1, z=1),
aspectmode='data'
),
width=int(chart_width),
height=int(chart_height),
margin=dict(r=20, l=10, b=10, t=50),
paper_bgcolor="rgb(0, 0, 0)",
plot_bgcolor="rgb(0, 0, 0)",
coloraxis_colorbar=dict(
title="Average Distance",
thicknessmode="pixels", thickness=20,
lenmode="pixels", len=400,
yanchor="top", y=1,
ticks="outside",
dtick=0.1
)
)
return fig
return create_3d_embedding_chart, prepare_embedding_data_3d
@app.cell
def helper_function_text_preparation():
def convert_table_to_json_docs(df, selected_columns=None):
"""
Convert a pandas DataFrame or dictionary to a list of JSON documents.
Dynamically includes columns based on user selection.
Column names are standardized to lowercase with underscores instead of spaces
and special characters removed.
Args:
df: The DataFrame or dictionary to process
selected_columns: List of column names to include in the output documents
Returns:
list: A list of dictionaries, each representing a row as a JSON document
"""
import pandas as pd
import re
def standardize_key(key):
"""Convert a column name to lowercase with underscores instead of spaces and no special characters"""
if not isinstance(key, str):
return str(key).lower()
# Replace spaces with underscores and convert to lowercase
key = key.lower().replace(' ', '_')
# Remove special characters (keeping alphanumeric and underscores)
return re.sub(r'[^\w]', '', key)
# Handle case when input is a dictionary
if isinstance(df, dict):
# Filter the dictionary to include only selected columns
if selected_columns:
return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
else:
# If no columns selected, return all key-value pairs with standardized keys
return [{standardize_key(k): v for k, v in df.items()}]
# Handle case when df is None
if df is None:
return []
# Ensure df is a DataFrame
if not isinstance(df, pd.DataFrame):
try:
df = pd.DataFrame(df)
except Exception:
return [] # Return empty list if conversion fails
# Now check if DataFrame is empty
if df.empty:
return []
# If no columns are specifically selected, use all available columns
if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
selected_columns = list(df.columns)
# Determine which columns exist in the DataFrame
available_columns = []
columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
for col in selected_columns:
if col in df.columns:
available_columns.append(col)
elif isinstance(col, str) and col.lower() in columns_lower:
available_columns.append(columns_lower[col.lower()])
# If no valid columns found, return empty list
if not available_columns:
return []
# Process rows
json_docs = []
for _, row in df.iterrows():
doc = {}
for col in available_columns:
value = row[col]
# Standardize the column name when adding to document
std_col = standardize_key(col)
doc[std_col] = None if pd.isna(value) else value
json_docs.append(doc)
return json_docs
def get_column_values(df, columns_to_include):
"""
Extract values from specified columns of a dataframe as lists.
Args:
df: A pandas DataFrame
columns_to_include: A list of column names to extract
Returns:
Dictionary with column names as keys and their values as lists
"""
result = {}
# Validate that columns exist in the dataframe
valid_columns = [col for col in columns_to_include if col in df.columns]
invalid_columns = set(columns_to_include) - set(valid_columns)
if invalid_columns:
print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")
# Extract values for each valid column
for col in valid_columns:
result[col] = df[col].tolist()
return result
def get_data_in_range(doc_dict_df, index_range, columns_to_include):
"""
Extract values from specified columns of a dataframe within a given index range.
Args:
doc_dict_df: The pandas DataFrame to extract data from
index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
columns_to_include: A list of column names to extract
Returns:
Dictionary with column names as keys and their values (within the index range) as lists
"""
# Validate the index range
max_index = len(doc_dict_df)
if index_range <= 0:
print(f"Warning: Invalid index range {index_range}. Must be positive.")
return {}
# Adjust index_range if it exceeds the dataframe length
if index_range > max_index:
print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
index_range = max_index
# Slice the dataframe to get rows from 0 to index_range-1
df_subset = doc_dict_df.iloc[:index_range]
# Use the provided get_column_values function to extract column data
return get_column_values(df_subset, columns_to_include)
def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
"""
Extract values from specified columns of a dataframe within a given index range.
Wraps string values with triple quotes and escapes URLs.
Args:
doc_dict_df: The pandas DataFrame to extract data from
index_range: A list of two integers specifying the start and end indices of rows to include
(e.g., [0, 10] includes rows from index 0 to 9 inclusive)
columns_to_include: A list of column names to extract
"""
# Validate the index range
start_idx, end_idx = index_range
max_index = len(doc_dict_df)
# Validate start index
if start_idx < 0:
print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
start_idx = 0
# Validate end index
if end_idx <= start_idx:
print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
end_idx = start_idx + 1
# Adjust end index if it exceeds the dataframe length
if end_idx > max_index:
print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
end_idx = max_index
# Slice the dataframe to get rows from start_idx to end_idx-1
# Using .loc with slice to preserve original indices
df_subset = doc_dict_df.iloc[start_idx:end_idx]
# Use the provided get_column_values function to extract column data
result = get_column_values(df_subset, columns_to_include)
# Process each string result to wrap in triple quotes
for col in result:
if isinstance(result[col], list):
# Create a new list with items wrapped in triple quotes
processed_items = []
for item in result[col]:
if isinstance(item, str):
# Replace http:// and https:// with escaped versions
item = item.replace("http://", "http\\://").replace("https://", "https\\://")
# processed_items.append('"""' + item + '"""')
processed_items.append(item)
else:
processed_items.append(item)
result[col] = processed_items
return result
return (get_data_in_range_triplequote,)
@app.cell
def prepare_doc_select(sentence_splitter_config):
def prepare_document_selection(node_dict):
"""
Creates document selection UI component.
Args:
node_dict: Dictionary mapping filenames to lists of documents
Returns:
mo.ui component for document selection
"""
# Calculate total number of documents across all files
total_docs = sum(len(docs) for docs in node_dict.values())
# Create a combined DataFrame of all documents for table selection
all_docs_records = []
doc_index_global = 0
for filename, docs in node_dict.items():
for i, doc in enumerate(docs):
# Convert the document to a format compatible with DataFrame
if hasattr(doc, 'to_dict'):
doc_data = doc.to_dict()
elif isinstance(doc, dict):
doc_data = doc
else:
doc_data = {'content': str(doc)}
# Add metadata
doc_data['filename'] = filename
doc_data['doc_index'] = i
doc_data['global_index'] = doc_index_global
all_docs_records.append(doc_data)
doc_index_global += 1
# Create UI component
stop_value = max(total_docs, 1)
llama_docs = mo.ui.range_slider(
start=1,
stop=stop_value,
step=1,
full_width=True,
show_value=True,
label="**Select a Range of Chunks to Visualize:**"
).form(submit_button_disabled=check_state(sentence_splitter_config.value))
return llama_docs
return (prepare_document_selection,)
@app.cell
def document_range_selection(
dict_from_nodes,
prepare_document_selection,
set_range_slider_state,
):
if dict_from_nodes is not None:
llama_docs = prepare_document_selection(dict_from_nodes)
set_range_slider_state(llama_docs)
else:
bare_dict = {}
llama_docs = prepare_document_selection(bare_dict)
return
@app.function
def create_cumulative_dataframe(dict_from_docs):
"""
Creates a cumulative DataFrame from a nested dictionary of documents.
Args:
dict_from_docs: Dictionary mapping filenames to lists of documents
Returns:
DataFrame with all documents flattened with global indices
"""
# Create a list to hold all document records
all_records = []
global_idx = 1 # Start from 1 to match range slider expectations
for filename, docs in dict_from_docs.items():
for i, doc in enumerate(docs):
# Convert the document to a dict format
if hasattr(doc, 'to_dict'):
doc_data = doc.to_dict()
elif isinstance(doc, dict):
doc_data = doc.copy()
else:
doc_data = {'content': str(doc)}
# Add additional metadata
doc_data['filename'] = filename
doc_data['doc_index'] = i
doc_data['global_index'] = global_idx
# If there's 'content' but no 'text', create a 'text' field
if 'content' in doc_data and 'text' not in doc_data:
doc_data['text'] = doc_data['content']
all_records.append(doc_data)
global_idx += 1
# Create DataFrame from all records
return pd.DataFrame(all_records)
@app.function
def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
"""
Create a list of stat objects for each item in the specified dictionary.
Parameters:
- texts_dict (dict): Dictionary containing the text data
- bordered (bool): Whether the stats should be bordered
- object_names (list or tuple): Two object names to use for label and value
[label_object, value_object]
- group_by_row (bool): Whether to group stats in rows (horizontal stacks)
- items_per_row (int): Number of stat objects per row when group_by_row is True
Returns:
- object: A vertical stack of stat objects or rows of stat objects
"""
if not object_names or len(object_names) < 2:
raise ValueError("You must provide two object names as a list or tuple")
label_object = object_names[0]
value_object = object_names[1]
# Validate that both objects exist in the dictionary
if label_object not in texts_dict:
raise ValueError(f"Label object '{label_object}' not found in texts_dict")
if value_object not in texts_dict:
raise ValueError(f"Value object '{value_object}' not found in texts_dict")
# Determine how many items to process (based on the label object length)
num_items = len(texts_dict[label_object])
# Create individual stat objects
individual_stats = []
for i in range(num_items):
stat = mo.stat(
label=texts_dict[label_object][i],
value=f"{label} Number: {len(texts_dict[value_object][i])}",
bordered=bordered
)
individual_stats.append(stat)
# If grouping is not enabled, just return a vertical stack of all stats
if not group_by_row:
return mo.vstack(individual_stats, wrap=False)
# Group stats into rows based on items_per_row
rows = []
for i in range(0, num_items, items_per_row):
# Get a slice of stats for this row (up to items_per_row items)
row_stats = individual_stats[i:i+items_per_row]
# Create a horizontal stack for this row
widths = [0.35] * len(row_stats)
row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths)
rows.append(row)
# Return a vertical stack of all rows
return mo.vstack(rows)
@app.cell
def prepare_chart_embeddings(
chunks_to_process,
emb_model,
emb_model_emb_dim,
get_embedding_state,
prepare_embedding_data_3d,
):
# chart_dataframe, chart_metadata = None, None
if chunks_to_process is not None and get_embedding_state() is not None:
chart_dataframe, chart_metadata = prepare_embedding_data_3d(
get_embedding_state(),
chunks_to_process,
model_id=emb_model,
embedding_dimensions=emb_model_emb_dim
)
else:
chart_dataframe, chart_metadata = None, None
return chart_dataframe, chart_metadata
@app.cell
def chart_dims():
chart_dimensions = (
mo.md('''
> **Adjust Chart Window**
{chart_height}
{chart_width}
''').batch(
chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
chart_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
)
)
return (chart_dimensions,)
@app.cell
def chart_dim_values(chart_dimensions):
chart_height = chart_dimensions.value['chart_height']
chart_width = chart_dimensions.value['chart_width']
return chart_height, chart_width
@app.cell
def create_baseline_chart(
chart_dataframe,
chart_height,
chart_metadata,
chart_width,
create_3d_embedding_chart,
):
if chart_dataframe is not None and chart_metadata is not None:
emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
chart = mo.ui.plotly(emb_plot)
else:
emb_plot = None
chart = None
return (emb_plot,)
@app.cell
def test_query(get_chunk_state):
placeholder = """How can I use watsonx.data to perform vector search?"""
query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(show_clear_button=True, submit_button_disabled=check_state(get_chunk_state()))
return (query,)
@app.cell
def query_stack(chart_dimensions, query):
# query_stack = mo.hstack([query], justify="space-around", align="center", widths=[0.65])
query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
return (query_stack,)
@app.function
def check_state(variable):
return variable is None
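# check_state() is used by the range-slider and query forms to disable their submit buttons until upstream data exists.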
@app.cell
def helper_function_add_query_to_chart():
def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
"""
Add a query point to an existing 3D embedding chart as a large red dot.
Args:
existing_chart: The existing plotly figure or chart data
query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
query_text: Text of the query to display on hover
marker_size: Size of the query marker (default: 18, typically 2x other markers)
Returns:
A modified plotly figure with the query point added as a red dot
"""
import plotly.graph_objects as go
# Create a deep copy of the existing chart to avoid modifying the original
import copy
chart_copy = copy.deepcopy(existing_chart)
# Handle case where chart_copy is a dictionary or list (from mo.ui.plotly)
if isinstance(chart_copy, (dict, list)):
# Create a new plotly figure from the data
import plotly.graph_objects as go
if isinstance(chart_copy, list):
# If it's a list, assume it's a list of traces
fig = go.Figure(data=chart_copy)
else:
# If it's a dict with 'data' and 'layout'
fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))
chart_copy = fig
# Create the query trace
query_trace = go.Scatter3d(
x=[query_coords['x']],
y=[query_coords['y']],
z=[query_coords['z']],
mode='markers',
name='Query',
marker=dict(
size=marker_size, # Typically 2x the size of other markers
color='red', # Bright red color
symbol='circle', # Circle shape
opacity=0.70, # Fully opaque
line=dict(
width=1, # Thin white border
color='white'
)
),
# text=['Query: ' + query_text],
text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])], ### Text Wrapping
hoverinfo="text+name"
)
# Add the query trace to the chart copy
chart_copy.add_trace(query_trace)
return chart_copy
def get_query_coordinates(reference_embeddings=None, query_embedding=None):
"""
Calculate appropriate coordinates for a query point based on reference embeddings.
This function handles several scenarios:
1. If both reference embeddings and query embedding are provided, it places the
query near similar documents.
2. If only reference embeddings are provided, it places the query at a visible
location near the center of the chart.
3. If neither are provided, it returns default origin coordinates.
Args:
reference_embeddings: DataFrame with x, y, z coordinates from the main chart
query_embedding: The embedding vector of the query
Returns:
Dictionary with x, y, z coordinates for the query point
"""
import numpy as np
# Default coordinates (origin with slight offset)
default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}
# If we don't have reference embeddings, return default
if reference_embeddings is None or len(reference_embeddings) == 0:
return default_coords
# If we have reference embeddings but no query embedding,
# position at a visible location near the center
if query_embedding is None:
center_coords = {
'x': reference_embeddings['x'].mean(),
'y': reference_embeddings['y'].mean(),
'z': reference_embeddings['z'].mean()
}
return center_coords
# If we have both reference embeddings and query embedding,
# try to position near similar documents
try:
from sklearn.metrics.pairwise import cosine_similarity
# Check if original embeddings are in the dataframe
if 'embedding' in reference_embeddings.columns:
# Get all document embeddings as a 2D array
if isinstance(reference_embeddings['embedding'].iloc[0], list):
doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
else:
doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])
# Reshape query embedding for comparison
query_emb_array = np.array(query_embedding)
if query_emb_array.ndim == 1:
query_emb_array = query_emb_array.reshape(1, -1)
# Calculate cosine similarities
similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]
# Find the closest document
closest_idx = np.argmax(similarities)
# Use the position of the closest document, with slight offset for visibility
query_coords = {
'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
}
return query_coords
except Exception as e:
print(f"Error positioning query near similar documents: {e}")
# Fallback to center position if similarity calculation fails
center_coords = {
'x': reference_embeddings['x'].mean(),
'y': reference_embeddings['y'].mean(),
'z': reference_embeddings['z'].mean()
}
return center_coords
return add_query_to_embedding_chart, get_query_coordinates
@app.cell
def combined_chart_visualization(
add_query_to_embedding_chart,
chart_dataframe,
emb_plot,
embedding,
get_query_coordinates,
get_query_state,
query,
set_chart_state,
set_query_state,
):
# When a query is submitted, embed it and overlay it as a distinct point on the existing 3D chart.
if chart_dataframe is not None and query.value:
with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
query_emb = embedding.embed_documents([query.value])
set_query_state(query_emb)
_spinner.update("Preparing Query Coordinates") # --- --- ---
time.sleep(1.0)
# Get appropriate coordinates for the query
query_coords = get_query_coordinates(
reference_embeddings=chart_dataframe,
query_embedding=get_query_state()
)
_spinner.update("Adding Query to Chart") # --- --- ---
time.sleep(1.0)
# Add the query to the chart with closest points highlighted
result = add_query_to_embedding_chart(
existing_chart=emb_plot,
query_coords=query_coords,
query_text=query.value,
)
chart_with_query = result
_spinner.update("Preparing Visualization") # --- --- ---
time.sleep(1.0)
# Create the visualization
combined_viz = mo.ui.plotly(chart_with_query)
set_chart_state(combined_viz)
_spinner.update("Done") # --- --- ---
else:
combined_viz = None
return
@app.cell
def _():
get_range_slider_state, set_range_slider_state = mo.state(None)
return get_range_slider_state, set_range_slider_state
@app.cell
def _(get_range_slider_state):
if get_range_slider_state() is not None:
document_range_stack = get_range_slider_state()
else:
document_range_stack = None
return (document_range_stack,)
@app.cell
def _():
get_chart_state, set_chart_state = mo.state(None)
return get_chart_state, set_chart_state
@app.cell
def _(get_chart_state, query):
if query.value is not None:
chart_visualization = get_chart_state()
else:
chart_visualization = None
return (chart_visualization,)
@app.cell
def c(document_range_stack):
chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
return (chart_range_selection,)
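# Launch with `marimo edit visualizer_app.py` to explore interactively, or `marimo run visualizer_app.py` to serve it as an app.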
if __name__ == "__main__":
app.run()