import marimo __generated_with = "0.13.0" app = marimo.App(width="full") with app.setup: # Initialization code that runs before all other cells import marimo as mo from typing import Dict, Optional, List, Union, Any from ibm_watsonx_ai import APIClient, Credentials from pathlib import Path import pandas as pd import mimetypes import requests import zipfile import tempfile import certifi import base64 import polars import nltk import time import json import ast import os import io import re def get_iam_token(api_key): return requests.post( 'https://iam.cloud.ibm.com/identity/token', headers={'Content-Type': 'application/x-www-form-urlencoded'}, data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key}, verify=certifi.where() ).json()['access_token'] def setup_task_credentials(client): # Get existing task credentials existing_credentials = client.task_credentials.get_details() # Delete existing credentials if any if "resources" in existing_credentials and existing_credentials["resources"]: for cred in existing_credentials["resources"]: cred_id = client.task_credentials.get_id(cred) client.task_credentials.delete(cred_id) # Store new credentials return client.task_credentials.store() def get_cred_value(key, creds_var_name="baked_in_creds", default=""): ### Helper for working with preset credentials """ Helper function to safely get a value from a credentials dictionary. Args: key: The key to look up in the credentials dictionary. creds_var_name: The variable name of the credentials dictionary. default: The default value to return if the key is not found. Returns: The value from the credentials dictionary if it exists and contains the key, otherwise returns the default value. 
""" # Check if the credentials variable exists in globals if creds_var_name in globals(): creds_dict = globals()[creds_var_name] if isinstance(creds_dict, dict) and key in creds_dict: return creds_dict[key] return default @app.cell def client_variables(client_instantiation_form): if client_instantiation_form.value: client_setup = client_instantiation_form.value else: client_setup = None ### Extract Credential Variables: if client_setup is not None: wx_url = client_setup["wx_region"] wx_api_key = client_setup["wx_api_key"].strip() os.environ["WATSONX_APIKEY"] = wx_api_key if client_setup["project_id"] is not None: project_id = client_setup["project_id"].strip() else: project_id = None if client_setup["space_id"] is not None: space_id = client_setup["space_id"].strip() else: space_id = None else: os.environ["WATSONX_APIKEY"] = "" project_id = None space_id = None wx_api_key = None wx_url = None return client_setup, project_id, space_id, wx_api_key, wx_url @app.cell def _(client_setup, wx_api_key): if client_setup: token = get_iam_token(wx_api_key) else: token = None return @app.cell def _(): baked_in_creds = { "purpose": "", "api_key": "", "project_id": "", "space_id": "", } return baked_in_creds @app.cell def client_instantiation( client_setup, project_id, space_id, wx_api_key, wx_url, ): ### Instantiate the watsonx.ai client if client_setup: wx_credentials = Credentials( url=wx_url, api_key=wx_api_key ) if project_id: project_client = APIClient(credentials=wx_credentials, project_id=project_id) else: project_client = None if space_id: deployment_client = APIClient(credentials=wx_credentials, space_id=space_id) else: deployment_client = None if project_client is not None: task_credentials_details = setup_task_credentials(project_client) else: task_credentials_details = setup_task_credentials(deployment_client) else: wx_credentials = None project_client = None deployment_client = None task_credentials_details = None client_status = mo.md("### Client Instantiation 
Status will turn Green When Ready") if project_client is not None or deployment_client is not None: client_callout_kind = "success" else: client_callout_kind = "neutral" return ( client_callout_kind, client_status, deployment_client, project_client, ) @app.cell def _(): mo.md( r""" #watsonx.ai Embedding Visualizer - Marimo Notebook #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages. #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy or contents used in building collections by showing how far you are from what you want.
/// admonition Created by ***Milan Mrdenovic*** [milan.mrdenovic@ibm.com] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025* /// >Licensed under apache 2.0, users hold full accountability for any use or modification of the code. >
This asset is part of a set meant to support IBMers, IBM Partners, Clients in developing understanding of how to better utilize various watsonx features and generative AI as a subject matter.
""" ) return @app.cell def _(): mo.md("""###Part 1 - Client Setup, File Preparation and Chunking""") return @app.cell def accordion_client_setup(client_selector, client_stack): ui_accordion_part_1_1 = mo.accordion( { "Instantiate Client": mo.vstack([client_stack, client_selector], align="center"), } ) ui_accordion_part_1_1 return @app.cell def accordion_file_upload(select_stack): ui_accordion_part_1_2 = mo.accordion( { "Select Model & Upload Files": select_stack } ) ui_accordion_part_1_2 return @app.cell def loaded_texts( create_temp_files_from_uploads, file_loader, pdf_reader, run_upload_button, set_text_state, ): if file_loader.value is not None and run_upload_button.value: filepaths = create_temp_files_from_uploads(file_loader.value) loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True) set_text_state(loaded_texts) else: filepaths = None loaded_texts = None return @app.cell def accordion_chunker_setup(chunker_setup): ui_accordion_part_1_3 = mo.accordion( { "Chunker Setup": chunker_setup } ) ui_accordion_part_1_3 return @app.cell def chunk_documents_to_nodes( get_text_state, sentence_splitter, sentence_splitter_config, set_chunk_state, ): if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None: chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True) set_chunk_state(chunked_texts) else: chunked_texts = None return (chunked_texts,) @app.cell def _(): mo.md(r"""###Part 2 - Query Setup and Visualization""") return @app.cell def accordion_chunk_range(chart_range_selection): ui_accordion_part_2_1 = mo.accordion( { "Chunk Range Selection": chart_range_selection } ) ui_accordion_part_2_1 return @app.cell def chunk_embedding( chunks_to_process, embedding, sentence_splitter_config, set_embedding_state, ): if sentence_splitter_config.value is not None and chunks_to_process is not None: with mo.status.spinner(title="Embedding Documents...", 
remove_on_exit=True) as _spinner: output_embeddings = embedding.embed_documents(chunks_to_process) _spinner.update("Almost Done") time.sleep(1.5) set_embedding_state(output_embeddings) _spinner.update("Documents Embedded") else: output_embeddings = None return @app.cell def preview_chunks(chunks_dict): if chunks_dict is not None: stats = create_stats(chunks_dict, bordered=True, object_names=['text','text'], group_by_row=True, items_per_row=5, gap=1, label="Chunk") ui_chunk_viewer = mo.accordion( { "View Chunks": stats, } ) else: ui_chunk_viewer = None ui_chunk_viewer return @app.cell def accordion_query_view(chart_visualization, query_stack): ui_accordion_part_2_2 = mo.accordion( { "Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3) } ) ui_accordion_part_2_2 return @app.cell def chunker_setup(sentence_splitter_config): chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55]) return (chunker_setup,) @app.cell def file_and_model_select( file_loader, get_embedding_model_list, run_upload_button, ): select_stack = mo.hstack([get_embedding_model_list(), mo.vstack([file_loader, run_upload_button], align="center")], justify="space-around", align="center", widths=[0.3,0.3]) return (select_stack,) @app.cell def client_instantiation_form(): # Endpoints wx_platform_url = "https://api.dataplatform.cloud.ibm.com" regions = { "US": "https://us-south.ml.cloud.ibm.com", "EU": "https://eu-de.ml.cloud.ibm.com", "GB": "https://eu-gb.ml.cloud.ibm.com", "JP": "https://jp-tok.ml.cloud.ibm.com", "AU": "https://au-syd.ml.cloud.ibm.com", "CA": "https://ca-tor.ml.cloud.ibm.com" } # Create a form with multiple elements client_instantiation_form = ( mo.md(''' ###**watsonx.ai credentials:** {wx_region} {wx_api_key} {project_id} {space_id} > You can add either a project_id, space_id or both, **only one is required**. > If you provide both you can switch the active one in the dropdown. 
''') .batch( wx_region = mo.ui.dropdown(regions, label="Select your watsonx.ai region:", value="US", searchable=True), wx_api_key = mo.ui.text(placeholder="Add your IBM Cloud api-key...", label="IBM Cloud Api-key:", kind="password", value=get_cred_value('api_key', creds_var_name='baked_in_creds')), project_id = mo.ui.text(placeholder="Add your watsonx.ai project_id...", label="Project_ID:", kind="text", value=get_cred_value('project_id', creds_var_name='baked_in_creds')), space_id = mo.ui.text(placeholder="Add your watsonx.ai space_id...", label="Space_ID:", kind="text", value=get_cred_value('space_id', creds_var_name='baked_in_creds')) ,) .form(show_clear_button=True, bordered=False) ) return (client_instantiation_form,) @app.cell def instantiation_status( client_callout_kind, client_instantiation_form, client_status, ): client_callout = mo.callout(client_status, kind=client_callout_kind) client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10) return (client_stack,) @app.cell def client_selector(deployment_client, project_client): if deployment_client is not None: client_options = {"Deployment Client":deployment_client} elif project_client is not None: client_options = {"Project Client":project_client} elif project_client is not None and deployment_client is not None: client_options = {"Project Client":project_client,"Deployment Client":deployment_client} else: client_options = {"No Client": "Instantiate a Client"} default_client = next(iter(client_options)) client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Select your active client:**") return (client_selector,) @app.cell def active_client(client_selector): client_key = client_selector.value if client_key == "Instantiate a Client": client = None else: client = client_key return (client,) @app.cell def emb_model_selection(client, set_embedding_model_list): if client is not None: model_specs = 
client.foundation_models.get_embeddings_model_specs() # model_specs = client.foundation_models.get_model_specs() resources = model_specs["resources"] # Define embedding models reference data embedding_models = { "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384}, "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768}, "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768}, "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768}, "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384}, "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384}, "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384}, "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384}, "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024} } # Get model IDs from resources model_id_list = [] for resource in resources: model_id_list.append(resource["model_id"]) # Create enhanced model data for the table embedding_model_data = [] for model_id in model_id_list: model_entry = {"model_id": model_id} # Add properties if model exists in our reference, otherwise use 0 if model_id in embedding_models: model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"] model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"] else: model_entry["max_tokens"] = 0 model_entry["embedding_dimensions"] = 0 embedding_model_data.append(model_entry) embedding_model_selection = mo.ui.table( embedding_model_data, selection="single", # Only allow selecting one row label="Select an embedding model to use.", page_size=30, initial_selection=[1] ) set_embedding_model_list(embedding_model_selection) else: default_model_data = [{ "model_id": "ibm/granite-embedding-107m-multilingual", "max_tokens": 512, 
"embedding_dimensions": 384 }] set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use.")) return @app.function def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."): embedding_model_selection = mo.ui.table( model_data, selection=selection_type, # Only allow selecting one row label=label, page_size=30, initial_selection=[initial_selection] ) return embedding_model_selection @app.cell def embedding_model(): get_embedding_model_list, set_embedding_model_list = mo.state(None) return get_embedding_model_list, set_embedding_model_list @app.cell def emb_model_parameters(emb_model_max_tk): from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams if embedding_model is not None: embed_params = { EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk, EmbedParams.RETURN_OPTIONS: { 'input_text': True } } else: embed_params = { EmbedParams.TRUNCATE_INPUT_TOKENS: 128, EmbedParams.RETURN_OPTIONS: { 'input_text': True } } return embed_params @app.cell def emb_model_state(get_embedding_model_list): embedding_model = get_embedding_model_list() return (embedding_model,) @app.cell def emb_model_setup(embedding_model): if embedding_model is not None: emb_model = embedding_model.value[0]['model_id'] emb_model_max_tk = embedding_model.value[0]['max_tokens'] emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions'] else: emb_model = None emb_model_max_tk = None emb_model_emb_dim = None return emb_model, emb_model_emb_dim, emb_model_max_tk @app.cell def emb_model_instantiation(client, emb_model, embed_params): from ibm_watsonx_ai.foundation_models import Embeddings if client is not None: embedding = Embeddings( model_id=emb_model, api_client=client, params=embed_params, batch_size=1000, concurrency_limit=10 ) else: embedding = None return (embedding,) @app.cell def _(): get_embedding_state, 
set_embedding_state = mo.state(None) return get_embedding_state, set_embedding_state @app.cell def _(): get_query_state, set_query_state = mo.state(None) return get_query_state, set_query_state @app.cell def file_loader_input(): file_loader = mo.ui.file( kind="area", filetypes=[".pdf"], label=" Load .pdf files ", multiple=True ) return (file_loader,) @app.cell def file_loader_run(file_loader): if file_loader.value: run_upload_button = mo.ui.run_button(label="Load Files") else: run_upload_button = mo.ui.run_button(disabled=True, label="Load Files") return (run_upload_button,) @app.cell def helper_function_tempfiles(): def create_temp_files_from_uploads(upload_results) -> List[str]: """ Creates temporary files from a tuple of FileUploadResults objects and returns their paths. Args: upload_results: Object containing a value attribute that is a tuple of FileUploadResults Returns: List of temporary file paths """ temp_file_paths = [] # Get the number of items in the tuple num_items = len(upload_results) # Process each item by index for i in range(num_items): result = upload_results[i] # Get item by index # Create a temporary file with the original filename temp_dir = tempfile.gettempdir() file_name = result.name temp_path = os.path.join(temp_dir, file_name) # Write the contents to the temp file with open(temp_path, 'wb') as temp_file: temp_file.write(result.contents) # Add the path to our list temp_file_paths.append(temp_path) return temp_file_paths def cleanup_temp_files(temp_file_paths: List[str]) -> None: """Delete temporary files after use.""" for path in temp_file_paths: if os.path.exists(path): os.unlink(path) return (create_temp_files_from_uploads,) @app.function def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True): """ Loads PDF data for each file path and organizes results by original filename. 
Args: pdf_reader: The PyMuPDFReader instance filepaths: List of temporary file paths file_loader_value: The original upload results value containing file information show_progress: Whether to show a progress bar during loading (default: False) Returns: Dictionary mapping original filenames to their loaded text content """ results = {} # Process files with or without progress bar if show_progress: import marimo as mo # Use progress bar with the length of filepaths as total with mo.status.progress_bar( total=len(filepaths), title="Loading PDFs", subtitle="Processing documents...", completion_title="PDF Loading Complete", completion_subtitle=f"{len(filepaths)} documents processed", remove_on_exit=True ) as bar: # Process each file path for i, file_path in enumerate(filepaths): original_file_name = file_loader_value[i].name bar.update(subtitle=f"Processing {original_file_name}...") loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True) # Store the result with the original filename as the key results[original_file_name] = loaded_text # Update progress bar bar.update(increment=1) else: # Original logic without progress bar for i, file_path in enumerate(filepaths): original_file_name = file_loader_value[i].name loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True) results[original_file_name] = loaded_text return results @app.cell def file_readers(): from llama_index.readers.file import PyMuPDFReader from llama_index.readers.file import FlatReader from llama_index.core.node_parser import SentenceSplitter ### File Readers pdf_reader = PyMuPDFReader() # flat_file_reader = FlatReader() return SentenceSplitter, pdf_reader @app.cell def sentence_splitter_setup(): ### Chunker Setup sentence_splitter_config = ( mo.md(''' ###**Chunking Setup:** > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might result in unexpected results. 
Separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**. {chunk_size} {chunk_overlap} {separator} {paragraph_separator} {secondary_chunking_regex} {include_metadata} ''') .batch( chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=350, show_value=True, full_width=True), chunk_overlap = mo.ui.slider(start=1, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=50, show_value=True, full_width=True), separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "), paragraph_separator = mo.ui.text(placeholder="Define a paragraph separator", label="**Paragraph Separator:**", kind="text", value="\n\n\n"), secondary_chunking_regex = mo.ui.text(placeholder="Define a secondary chunking regex", label="**Chunking Regex:**", kind="text", value="[^,.;?!]+[,.;?!]?"), include_metadata= mo.ui.checkbox(value=True, label="**Include Metadata**") ) .form(show_clear_button=True, bordered=False) ) return (sentence_splitter_config,) @app.cell def sentence_splitter_instantiation( SentenceSplitter, sentence_splitter_config, ): ### Chunker/Sentence Splitter def simple_whitespace_tokenizer(text): return text.split() if sentence_splitter_config.value is not None: sentence_splitter_config_values = sentence_splitter_config.value validated_chunk_overlap = min(sentence_splitter_config_values.get("chunk_overlap"), int(sentence_splitter_config_values.get("chunk_size") * 0.3)) sentence_splitter = SentenceSplitter( chunk_size=sentence_splitter_config_values.get("chunk_size"), chunk_overlap=validated_chunk_overlap, separator=sentence_splitter_config_values.get("separator"), paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"), secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"), include_metadata=sentence_splitter_config_values.get("include_metadata"), 
tokenizer=simple_whitespace_tokenizer ) else: sentence_splitter = SentenceSplitter( chunk_size=2048, chunk_overlap=204, separator=" ", paragraph_separator="\n\n\n", secondary_chunking_regex="[^,.;?!]+[,.;?!]?", include_metadata=True, tokenizer=simple_whitespace_tokenizer ) return (sentence_splitter,) @app.cell def text_state(): get_text_state, set_text_state = mo.state(None) return get_text_state, set_text_state @app.cell def chunk_state(): get_chunk_state, set_chunk_state = mo.state(None) return get_chunk_state, set_chunk_state @app.function def chunk_documents(loaded_texts, sentence_splitter, show_progress=True): """ Process each document in the loaded_texts dictionary using the sentence_splitter, with an optional marimo progress bar tracking progress at document level. Args: loaded_texts (dict): Dictionary containing lists of Document objects sentence_splitter: The sentence splitter object with get_nodes_from_documents method show_progress (bool): Whether to show a progress bar during processing Returns: dict: Dictionary with the same structure but containing chunked texts """ chunked_texts_dict = {} # Get the total number of documents across all keys total_docs = sum(len(docs) for docs in loaded_texts.values()) processed_docs = 0 # Process with or without progress bar if show_progress: import marimo as mo # Use progress bar with the total number of documents as total with mo.status.progress_bar( total=total_docs, title="Processing Documents", subtitle="Chunking documents...", completion_title="Processing Complete", completion_subtitle=f"{total_docs} documents processed", remove_on_exit=True ) as bar: # Process each key-value pair in the loaded_texts dictionary for key, documents in loaded_texts.items(): # Update progress bar subtitle to show current key doc_count = len(documents) bar.update(subtitle=f"Chunking {key}... 
({doc_count} documents)") # Apply the sentence splitter to each list of documents chunked_texts = sentence_splitter.get_nodes_from_documents( documents, show_progress=False # Disable internal progress to avoid nested bars ) # Store the result with the same key chunked_texts_dict[key] = chunked_texts time.sleep(0.15) # Update progress bar with the number of documents in this batch bar.update(increment=doc_count) processed_docs += doc_count else: # Process without progress bar for key, documents in loaded_texts.items(): chunked_texts = sentence_splitter.get_nodes_from_documents( documents, show_progress=True # Use the internal progress bar if no marimo bar ) chunked_texts_dict[key] = chunked_texts return chunked_texts_dict @app.cell def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter): if chunked_texts is not None and sentence_splitter: chunked_documents = get_chunk_state() else: chunked_documents = None return (chunked_documents,) @app.cell def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi): if chunked_documents is not None: dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents) nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes) else: dict_from_nodes = None nodes_from_dict = None return (dict_from_nodes,) @app.cell def chunks_to_process( dict_from_nodes, document_range_stack, get_data_in_range_triplequote, ): if dict_from_nodes is not None and document_range_stack is not None: chunk_dict_df = create_cumulative_dataframe(dict_from_nodes) if document_range_stack.value is not None: chunk_start_idx = document_range_stack.value[0] chunk_end_idx = document_range_stack.value[1] else: chunk_start_idx = 0 chunk_end_idx = len(chunk_dict_df) chunk_range_index = [chunk_start_idx, chunk_end_idx] chunks_dict = get_data_in_range_triplequote(chunk_dict_df, index_range=chunk_range_index, columns_to_include=["text"]) chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else [] else: chunk_objects = None 
chunks_dict = None chunks_to_process = None return chunks_dict, chunks_to_process @app.cell def helper_function_doc_formatting(): def llamaindex_convert_docs_multi(items): """ Automatically convert between document objects and dictionaries. This function handles: - Converting dictionaries to document objects - Converting document objects to dictionaries - Processing lists or individual items - Supporting dictionary structures where values are lists of documents Args: items: A document object, dictionary, or list of either. Can also be a dictionary mapping filenames to lists of documents. Returns: Converted item(s) maintaining the original structure """ # Handle empty or None input if not items: return [] # Handle dictionary mapping filenames to document lists (from load_pdf_data) if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()): result = {} for filename, doc_list in items.items(): result[filename] = llamaindex_convert_docs(doc_list) return result # Handle single items (not in a list) if not isinstance(items, list): # Single dictionary to document if isinstance(items, dict): # Determine document class doc_class = None if 'doc_type' in items: import importlib module_path, class_name = items['doc_type'].rsplit('.', 1) module = importlib.import_module(module_path) doc_class = getattr(module, class_name) if not doc_class: from llama_index.core.schema import Document doc_class = Document return doc_class.from_dict(items) # Single document to dictionary elif hasattr(items, 'to_dict'): return items.to_dict() # Return as is if can't convert return items # Handle list input result = [] # Handle empty list if len(items) == 0: return result # Determine the type of conversion based on the first non-None item first_item = next((item for item in items if item is not None), None) # If we found no non-None items, return empty list if first_item is None: return result # Convert dictionaries to documents if isinstance(first_item, dict): # Get the right 
document class from the items themselves doc_class = None # Try to get doc class from metadata if available if 'doc_type' in first_item: import importlib module_path, class_name = first_item['doc_type'].rsplit('.', 1) module = importlib.import_module(module_path) doc_class = getattr(module, class_name) if not doc_class: # Fallback to default Document class from llama_index from llama_index.core.schema import Document doc_class = Document # Convert each dictionary to document for item in items: if isinstance(item, dict): result.append(doc_class.from_dict(item)) elif item is None: result.append(None) elif isinstance(item, list): result.append(llamaindex_convert_docs(item)) else: result.append(item) # Convert documents to dictionaries else: for item in items: if hasattr(item, 'to_dict'): result.append(item.to_dict()) elif item is None: result.append(None) elif isinstance(item, list): result.append(llamaindex_convert_docs(item)) else: result.append(item) return result def llamaindex_convert_docs(items): """ Automatically convert between document objects and dictionaries. 
Args: items: A list of document objects or dictionaries Returns: List of converted items (dictionaries or document objects) """ result = [] # Handle empty or None input if not items: return result # Determine the type of conversion based on the first item if isinstance(items[0], dict): # Get the right document class from the items themselves # Look for a 'doc_type' or '__class__' field in the dictionary doc_class = None # Try to get doc class from metadata if available if 'doc_type' in items[0]: import importlib module_path, class_name = items[0]['doc_type'].rsplit('.', 1) module = importlib.import_module(module_path) doc_class = getattr(module, class_name) if not doc_class: # Fallback to default Document class from llama_index from llama_index.core.schema import Document doc_class = Document # Convert dictionaries to documents for item in items: if isinstance(item, dict): result.append(doc_class.from_dict(item)) else: # Convert documents to dictionaries for item in items: if hasattr(item, 'to_dict'): result.append(item.to_dict()) return result return (llamaindex_convert_docs_multi,) @app.cell def helper_function_create_df(): def create_document_dataframes(dict_from_docs): """ Creates a pandas DataFrame for each file in the dictionary. 
Args: dict_from_docs: Dictionary mapping filenames to lists of documents Returns: List of pandas DataFrames, each representing all documents from a single file """ dataframes = [] for filename, docs in dict_from_docs.items(): # Create a list to hold all document records for this file file_records = [] for i, doc in enumerate(docs): # Convert the document to a format compatible with DataFrame if hasattr(doc, 'to_dict'): doc_data = doc.to_dict() elif isinstance(doc, dict): doc_data = doc else: doc_data = {'content': str(doc)} # Add document index information doc_data['doc_index'] = i # Add to the list of records for this file file_records.append(doc_data) # Create a single DataFrame for all documents in this file if file_records: df = pd.DataFrame(file_records) df['filename'] = filename # Add filename as a column dataframes.append(df) return dataframes def create_dataframe_previews(dataframe_list, page_size=5): """ Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list. 
Args: dataframe_list: List of pandas DataFrames (output from create_document_dataframes) page_size: Number of rows to show per page for each component Returns: List of mo.ui.dataframe components """ # Create a list of mo.ui.dataframe components preview_components = [] for df in dataframe_list: # Create a mo.ui.dataframe component for this DataFrame preview = mo.ui.dataframe(df, page_size=page_size) preview_components.append(preview) return preview_components return @app.cell def helper_function_chart_preparation(): import altair as alt import numpy as np import plotly.express as px from sklearn.manifold import TSNE def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None): """ Prepare embedding data for visualization Args: embeddings: List of embeddings arrays texts: List of text strings model_id: Embedding model ID (optional) embedding_dimensions: Embedding dimensions (optional) Returns: DataFrame with processed data and metadata """ # Flatten embeddings (in case they're nested) flattened_embeddings = [] for emb in embeddings: if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list): flattened_embeddings.append(emb[0]) # Take first element if nested else: flattened_embeddings.append(emb) # Convert to numpy array embedding_array = np.array(flattened_embeddings) # Apply dimensionality reduction (t-SNE) tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1)) reduced_embeddings = tsne.fit_transform(embedding_array) # Create truncated texts for display truncated_texts = [text[:50] + "..." 
if len(text) > 50 else text for text in texts] # Create dataframe for visualization df = pd.DataFrame({ "x": reduced_embeddings[:, 0], "y": reduced_embeddings[:, 1], "text": truncated_texts, "full_text": texts, "index": range(len(texts)) }) # Add metadata metadata = { "model_id": model_id, "embedding_dimensions": embedding_dimensions } return df, metadata def create_embedding_chart(df, metadata=None): """ Create an Altair chart for embedding visualization Args: df: DataFrame with x, y coordinates and text metadata: Dictionary with model_id and embedding_dimensions Returns: Altair chart """ model_id = metadata.get("model_id") if metadata else None embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None selection = alt.selection_multi(fields=['index']) base = alt.Chart(df).encode( x=alt.X("x:Q", title="Dimension 1"), y=alt.Y("y:Q", title="Dimension 2"), tooltip=["text", "index"] ) points = base.mark_circle(size=100).encode( color=alt.Color("index:N", legend=None), opacity=alt.condition(selection, alt.value(1), alt.value(0.2)) ).add_selection(selection) # Add this line to apply the selection text = base.mark_text(align="left", dx=7).encode( text="index:N" ) return (points + text).properties( width=700, height=500, title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}" ).interactive() def show_selected_text(indices, texts): """ Create markdown display for selected texts Args: indices: List of selected indices texts: List of all texts Returns: Markdown string """ if not indices: return "No text selected" selected_texts = [texts[i] for i in indices if i < len(texts)] return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)]) def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None): """ Prepare embedding data for 3D visualization Args: embeddings: List of embeddings arrays texts: 
List of text strings model_id: Embedding model ID (optional) embedding_dimensions: Embedding dimensions (optional) Returns: DataFrame with processed data and metadata """ # Flatten embeddings (in case they're nested) flattened_embeddings = [] for emb in embeddings: if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list): flattened_embeddings.append(emb[0]) else: flattened_embeddings.append(emb) # Convert to numpy array embedding_array = np.array(flattened_embeddings) # Handle the case of a single embedding differently if len(embedding_array) == 1: # For a single point, we don't need t-SNE, just use a fixed position reduced_embeddings = np.array([[0.0, 0.0, 0.0]]) else: # Apply dimensionality reduction to 3D # Fix: Ensure perplexity is at least 1.0 perplexity_value = max(1.0, min(30, len(embedding_array)-1)) tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value) reduced_embeddings = tsne.fit_transform(embedding_array) # Format texts for display formatted_texts = [] for text in texts: # Truncate if needed if len(text) > 500: text = text[:500] + "..." # Insert line breaks for wrapping wrapped_text = "" for i in range(0, len(text), 50): wrapped_text += text[i:i+50] + "
" formatted_texts.append(""+wrapped_text+"") # Create dataframe for visualization df = pd.DataFrame({ "x": reduced_embeddings[:, 0], "y": reduced_embeddings[:, 1], "z": reduced_embeddings[:, 2], "text": formatted_texts, "full_text": texts, "index": range(len(texts)), "embedding": flattened_embeddings # Store the original embeddings for later use }) # Add metadata metadata = { "model_id": model_id, "embedding_dimensions": embedding_dimensions } return df, metadata def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int=3): """ Create a 3D Plotly chart for embedding visualization with proximity-based coloring """ model_id = metadata.get("model_id") if metadata else None embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None # Calculate the proximity between points from scipy.spatial.distance import pdist, squareform # Get the coordinates as a numpy array coords = df[['x', 'y', 'z']].values # Calculate pairwise distances dist_matrix = squareform(pdist(coords)) # For each point, find its average distance to all other points avg_distances = np.mean(dist_matrix, axis=1) # Add this to the dataframe - smaller values = closer to other points df['proximity'] = avg_distances # Create 3D scatter plot with proximity-based coloring fig = px.scatter_3d( df, x='x', y='y', z='z', # x='petal_length', # Changed from 'x' to 'petal_length' # y='petal_width', # Changed from 'y' to 'petal_width' # z='petal_height', color='proximity', # Color based on proximity color_continuous_scale='Viridis_r', # Reversed so closer points are warmer colors hover_data=['text', 'index', 'proximity'], labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'}, # labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'}, title=f"3D Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if 
embedding_dimensions else ''}", text='index', # size_max=marker_size_var ) # Update marker size and layout # fig.update_traces(marker=dict(size=3), selector=dict(mode='markers')) fig.update_traces( marker=dict( size=marker_size_var, # Very small marker size opacity=0.7, # Slightly transparent symbol="diamond", # Use circle markers (other options: "square", "diamond", "cross", "x") line=dict( width=0.5, # Very thin border color="white" # White outline makes small dots more visible ) ), textfont=dict( color="rgba(255, 255, 255, 0.3)", size=8 ), # hovertemplate="index=%{text}
%{customdata[0]}

Avg Distance=%{customdata[2]:.4f}", ### Hover Changes hovertemplate="text:
%{customdata[0]}
index: %{text}

Avg Distance: %{customdata[2]:.4f}", hoverinfo="text+name", hoverlabel=dict( bgcolor="white", # White background for hover labels font_size=12 # Font size for hover text ), selector=dict(type='scatter3d') ) # Keep your existing layout settings fig.update_layout( scene=dict( xaxis=dict( title='Dimension 1', nticks=40, backgroundcolor="rgb(10, 10, 20, 0.1)", gridcolor="white", showbackground=True, gridwidth=0.35, zerolinecolor="white", ), yaxis=dict( title='Dimension 2', nticks=40, backgroundcolor="rgb(10, 10, 20, 0.1)", gridcolor="white", showbackground=True, gridwidth=0.35, zerolinecolor="white", ), zaxis=dict( title='Dimension 3', nticks=40, backgroundcolor="rgb(10, 10, 20, 0.1)", gridcolor="white", showbackground=True, gridwidth=0.35, zerolinecolor="white", ), # Control camera view angle camera=dict( up=dict(x=0, y=0, z=1), center=dict(x=0, y=0, z=0), eye=dict(x=1.25, y=1.25, z=1.25), ), aspectratio=dict(x=1, y=1, z=1), aspectmode='data' ), width=int(chart_width), height=int(chart_height), margin=dict(r=20, l=10, b=10, t=50), paper_bgcolor="rgb(0, 0, 0)", plot_bgcolor="rgb(0, 0, 0)", coloraxis_colorbar=dict( title="Average Distance", thicknessmode="pixels", thickness=20, lenmode="pixels", len=400, yanchor="top", y=1, ticks="outside", dtick=0.1 ) ) return fig return create_3d_embedding_chart, prepare_embedding_data_3d @app.cell def helper_function_text_preparation(): def convert_table_to_json_docs(df, selected_columns=None): """ Convert a pandas DataFrame or dictionary to a list of JSON documents. Dynamically includes columns based on user selection. Column names are standardized to lowercase with underscores instead of spaces and special characters removed. 
Args: df: The DataFrame or dictionary to process selected_columns: List of column names to include in the output documents Returns: list: A list of dictionaries, each representing a row as a JSON document """ import pandas as pd import re def standardize_key(key): """Convert a column name to lowercase with underscores instead of spaces and no special characters""" if not isinstance(key, str): return str(key).lower() # Replace spaces with underscores and convert to lowercase key = key.lower().replace(' ', '_') # Remove special characters (keeping alphanumeric and underscores) return re.sub(r'[^\w]', '', key) # Handle case when input is a dictionary if isinstance(df, dict): # Filter the dictionary to include only selected columns if selected_columns: return [{standardize_key(k): df.get(k, None) for k in selected_columns}] else: # If no columns selected, return all key-value pairs with standardized keys return [{standardize_key(k): v for k, v in df.items()}] # Handle case when df is None if df is None: return [] # Ensure df is a DataFrame if not isinstance(df, pd.DataFrame): try: df = pd.DataFrame(df) except: return [] # Return empty list if conversion fails # Now check if DataFrame is empty if df.empty: return [] # If no columns are specifically selected, use all available columns if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0: selected_columns = list(df.columns) # Determine which columns exist in the DataFrame available_columns = [] columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)} for col in selected_columns: if col in df.columns: available_columns.append(col) elif isinstance(col, str) and col.lower() in columns_lower: available_columns.append(columns_lower[col.lower()]) # If no valid columns found, return empty list if not available_columns: return [] # Process rows json_docs = [] for _, row in df.iterrows(): doc = {} for col in available_columns: value = row[col] # Standardize the 
column name when adding to document std_col = standardize_key(col) doc[std_col] = None if pd.isna(value) else value json_docs.append(doc) return json_docs def get_column_values(df, columns_to_include): """ Extract values from specified columns of a dataframe as lists. Args: df: A pandas DataFrame columns_to_include: A list of column names to extract Returns: Dictionary with column names as keys and their values as lists """ result = {} # Validate that columns exist in the dataframe valid_columns = [col for col in columns_to_include if col in df.columns] invalid_columns = set(columns_to_include) - set(valid_columns) if invalid_columns: print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}") # Extract values for each valid column for col in valid_columns: result[col] = df[col].tolist() return result def get_data_in_range(doc_dict_df, index_range, columns_to_include): """ Extract values from specified columns of a dataframe within a given index range. Args: doc_dict_df: The pandas DataFrame to extract data from index_range: An integer specifying the number of rows to include (from 0 to index_range-1) columns_to_include: A list of column names to extract Returns: Dictionary with column names as keys and their values (within the index range) as lists """ # Validate the index range max_index = len(doc_dict_df) if index_range <= 0: print(f"Warning: Invalid index range {index_range}. Must be positive.") return {} # Adjust index_range if it exceeds the dataframe length if index_range > max_index: print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. 
Using maximum length.") index_range = max_index # Slice the dataframe to get rows from 0 to index_range-1 df_subset = doc_dict_df.iloc[:index_range] # Use the provided get_column_values function to extract column data return get_column_values(df_subset, columns_to_include) def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include): """ Extract values from specified columns of a dataframe within a given index range. Wraps string values with triple quotes and escapes URLs. Args: doc_dict_df: The pandas DataFrame to extract data from index_range: A list of two integers specifying the start and end indices of rows to include (e.g., [0, 10] includes rows from index 0 to 9 inclusive) columns_to_include: A list of column names to extract """ # Validate the index range start_idx, end_idx = index_range max_index = len(doc_dict_df) # Validate start index if start_idx < 0: print(f"Warning: Invalid start index {start_idx}. Using 0 instead.") start_idx = 0 # Validate end index if end_idx <= start_idx: print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.") end_idx = start_idx + 1 # Adjust end index if it exceeds the dataframe length if end_idx > max_index: print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. 
Using maximum length.") end_idx = max_index # Slice the dataframe to get rows from start_idx to end_idx-1 # Using .loc with slice to preserve original indices df_subset = doc_dict_df.iloc[start_idx:end_idx] # Use the provided get_column_values function to extract column data result = get_column_values(df_subset, columns_to_include) # Process each string result to wrap in triple quotes for col in result: if isinstance(result[col], list): # Create a new list with items wrapped in triple quotes processed_items = [] for item in result[col]: if isinstance(item, str): # Replace http:// and https:// with escaped versions item = item.replace("http://", "http\\://").replace("https://", "https\\://") # processed_items.append('"""' + item + '"""') processed_items.append(item) else: processed_items.append(item) result[col] = processed_items return result return (get_data_in_range_triplequote,) @app.cell def prepare_doc_select(sentence_splitter_config): def prepare_document_selection(node_dict): """ Creates document selection UI component. 
Args: node_dict: Dictionary mapping filenames to lists of documents Returns: mo.ui component for document selection """ # Calculate total number of documents across all files total_docs = sum(len(docs) for docs in node_dict.values()) # Create a combined DataFrame of all documents for table selection all_docs_records = [] doc_index_global = 0 for filename, docs in node_dict.items(): for i, doc in enumerate(docs): # Convert the document to a format compatible with DataFrame if hasattr(doc, 'to_dict'): doc_data = doc.to_dict() elif isinstance(doc, dict): doc_data = doc else: doc_data = {'content': str(doc)} # Add metadata doc_data['filename'] = filename doc_data['doc_index'] = i doc_data['global_index'] = doc_index_global all_docs_records.append(doc_data) doc_index_global += 1 # Create UI component stop_value = max(total_docs, 1) llama_docs = mo.ui.range_slider( start=1, stop=stop_value, step=1, full_width=True, show_value=True, label="**Select a Range of Chunks to Visualize:**" ).form(submit_button_disabled=check_state(sentence_splitter_config.value)) return llama_docs return (prepare_document_selection,) @app.cell def document_range_selection( dict_from_nodes, prepare_document_selection, set_range_slider_state, ): if dict_from_nodes is not None: llama_docs = prepare_document_selection(dict_from_nodes) set_range_slider_state(llama_docs) else: bare_dict = {} llama_docs = prepare_document_selection(bare_dict) return @app.function def create_cumulative_dataframe(dict_from_docs): """ Creates a cumulative DataFrame from a nested dictionary of documents. 
Args: dict_from_docs: Dictionary mapping filenames to lists of documents Returns: DataFrame with all documents flattened with global indices """ # Create a list to hold all document records all_records = [] global_idx = 1 # Start from 1 to match range slider expectations for filename, docs in dict_from_docs.items(): for i, doc in enumerate(docs): # Convert the document to a dict format if hasattr(doc, 'to_dict'): doc_data = doc.to_dict() elif isinstance(doc, dict): doc_data = doc.copy() else: doc_data = {'content': str(doc)} # Add additional metadata doc_data['filename'] = filename doc_data['doc_index'] = i doc_data['global_index'] = global_idx # If there's 'content' but no 'text', create a 'text' field if 'content' in doc_data and 'text' not in doc_data: doc_data['text'] = doc_data['content'] all_records.append(doc_data) global_idx += 1 # Create DataFrame from all records return pd.DataFrame(all_records) @app.function def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"): """ Create a list of stat objects for each item in the specified dictionary. 
Parameters: - texts_dict (dict): Dictionary containing the text data - bordered (bool): Whether the stats should be bordered - object_names (list or tuple): Two object names to use for label and value [label_object, value_object] - group_by_row (bool): Whether to group stats in rows (horizontal stacks) - items_per_row (int): Number of stat objects per row when group_by_row is True Returns: - object: A vertical stack of stat objects or rows of stat objects """ if not object_names or len(object_names) < 2: raise ValueError("You must provide two object names as a list or tuple") label_object = object_names[0] value_object = object_names[1] # Validate that both objects exist in the dictionary if label_object not in texts_dict: raise ValueError(f"Label object '{label_object}' not found in texts_dict") if value_object not in texts_dict: raise ValueError(f"Value object '{value_object}' not found in texts_dict") # Determine how many items to process (based on the label object length) num_items = len(texts_dict[label_object]) # Create individual stat objects individual_stats = [] for i in range(num_items): stat = mo.stat( label=texts_dict[label_object][i], value=f"{label} Number: {len(texts_dict[value_object][i])}", bordered=bordered ) individual_stats.append(stat) # If grouping is not enabled, just return a vertical stack of all stats if not group_by_row: return mo.vstack(individual_stats, wrap=False) # Group stats into rows based on items_per_row rows = [] for i in range(0, num_items, items_per_row): # Get a slice of stats for this row (up to items_per_row items) row_stats = individual_stats[i:i+items_per_row] # Create a horizontal stack for this row widths = [0.35] * len(row_stats) row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths) rows.append(row) # Return a vertical stack of all rows return mo.vstack(rows) @app.cell def prepare_chart_embeddings( chunks_to_process, emb_model, emb_model_emb_dim, get_embedding_state, 
prepare_embedding_data_3d, ): # chart_dataframe, chart_metadata = None, None if chunks_to_process is not None and get_embedding_state() is not None: chart_dataframe, chart_metadata = prepare_embedding_data_3d( get_embedding_state(), chunks_to_process, model_id=emb_model, embedding_dimensions=emb_model_emb_dim ) else: chart_dataframe, chart_metadata = None, None return chart_dataframe, chart_metadata @app.cell def chart_dims(): chart_dimensions = ( mo.md(''' > **Adjust Chart Window** {chart_height} {chat_width} ''').batch( chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True), chat_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True) ) ) return (chart_dimensions,) @app.cell def chart_dim_values(chart_dimensions): chart_height = chart_dimensions.value['chart_height'] chart_width = chart_dimensions.value['chat_width'] return chart_height, chart_width @app.cell def create_baseline_chart( chart_dataframe, chart_height, chart_metadata, chart_width, create_3d_embedding_chart, ): if chart_dataframe is not None and chart_metadata is not None: emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9) chart = mo.ui.plotly(emb_plot) else: emb_plot = None chart = None return (emb_plot,) @app.cell def test_query(get_chunk_state): placeholder = """How can i use watsonx.data to perform vector search?""" query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(show_clear_button=True, submit_button_disabled=check_state(get_chunk_state())) return (query,) @app.cell def query_stack(chart_dimensions, query): # query_stack = mo.hstack([query], justify="space-around", align="center", widths=[0.65]) query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15) return (query_stack,) @app.function def check_state(variable): return variable is None 
@app.cell
def helper_function_add_query_to_chart():
    def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
        """
        Add a query point to an existing 3D embedding chart as a red marker.

        Args:
            existing_chart: A plotly Figure, a dict with 'data'/'layout', or a
                list of traces. It is deep-copied and never modified.
            query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
            query_text: Text of the query to display on hover
            marker_size: Size of the query marker (default: 12)

        Returns:
            A new plotly Figure with the query added as a trace named "Query".
        """
        import copy

        import plotly.graph_objects as go

        chart_copy = copy.deepcopy(existing_chart)

        # Accept raw figure data (list of traces or data/layout dict) as well as a Figure
        if isinstance(chart_copy, list):
            chart_copy = go.Figure(data=chart_copy)
        elif isinstance(chart_copy, dict):
            chart_copy = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))

        # Wrap the query every 50 characters; Plotly hover text breaks lines on "<br>"
        wrapped_query = "<br>".join(query_text[i:i + 50] for i in range(0, len(query_text), 50))

        query_trace = go.Scatter3d(
            x=[query_coords['x']],
            y=[query_coords['y']],
            z=[query_coords['z']],
            mode='markers',
            name='Query',
            marker=dict(
                size=marker_size,
                color='red',
                symbol='circle',
                opacity=0.70,  # Slightly transparent so nearby points remain visible
                line=dict(
                    width=1,
                    color='white',
                ),
            ),
            text=['Query:<br>' + wrapped_query],
            hoverinfo="text+name",
        )

        chart_copy.add_trace(query_trace)
        return chart_copy

    def get_query_coordinates(reference_embeddings=None, query_embedding=None):
        """
        Choose where to draw a query point on the 3D embedding chart.

        t-SNE coordinates cannot be computed for a new point without refitting,
        so the query is placed next to its most similar document instead:

        1. No reference embeddings: the origin.
        2. No query embedding, or no 'embedding' column to compare against:
           the centroid of the chart points.
        3. Otherwise: the chart position of the document with the highest
           cosine similarity to the query, offset by +0.2 on each axis so both
           markers stay visible. Any error during the similarity step prints a
           message and falls back to the centroid.

        Args:
            reference_embeddings: DataFrame with x, y, z (and optionally
                'embedding') columns from the main chart
            query_embedding: The embedding vector of the query (1-D, or 2-D
                with a single row)

        Returns:
            Dictionary with x, y, z coordinates for the query point
        """
        import numpy as np

        if reference_embeddings is None or len(reference_embeddings) == 0:
            return {'x': 0.0, 'y': 0.0, 'z': 0.0}

        center_coords = {
            'x': reference_embeddings['x'].mean(),
            'y': reference_embeddings['y'].mean(),
            'z': reference_embeddings['z'].mean(),
        }

        # Without both vectors there is nothing to compare; previously this path
        # (missing 'embedding' column) fell through and returned None.
        if query_embedding is None or 'embedding' not in reference_embeddings.columns:
            return center_coords

        try:
            from sklearn.metrics.pairwise import cosine_similarity

            doc_embeddings = np.array(reference_embeddings['embedding'].tolist())

            query_emb_array = np.array(query_embedding)
            if query_emb_array.ndim == 1:
                query_emb_array = query_emb_array.reshape(1, -1)

            similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]
            closest_idx = np.argmax(similarities)

            return {
                'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
                'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
                'z': reference_embeddings['z'].iloc[closest_idx] + 0.2,
            }
        except Exception as e:
            # Best-effort placement: fall back to the centroid rather than failing the chart
            print(f"Error positioning query near similar documents: {e}")
            return center_coords
    return add_query_to_embedding_chart, get_query_coordinates


@app.cell
def combined_chart_visualization(
    add_query_to_embedding_chart,
    chart_dataframe,
    emb_plot,
    embedding,
    get_query_coordinates,
    get_query_state,
    query,
    set_chart_state,
    set_query_state,
):
    # Embed the submitted query, place it on the chart, and publish the result via chart state
    if chart_dataframe is not None and query.value:
        with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
            query_emb = embedding.embed_documents([query.value])
            set_query_state(query_emb)

            _spinner.update("Preparing Query Coordinates")
            time.sleep(1.0)  # presumably a pause so each spinner message is readable
            query_coords = get_query_coordinates(
                reference_embeddings=chart_dataframe,
                query_embedding=get_query_state(),
            )

            _spinner.update("Adding Query to Chart")
            time.sleep(1.0)
            chart_with_query = add_query_to_embedding_chart(
                existing_chart=emb_plot,
                query_coords=query_coords,
                query_text=query.value,
            )

            _spinner.update("Preparing Visualization")
            time.sleep(1.0)
            combined_viz = mo.ui.plotly(chart_with_query)
            set_chart_state(combined_viz)
            _spinner.update("Done")
    else:
        combined_viz = None
    return


@app.cell
def _():
    get_range_slider_state, set_range_slider_state = mo.state(None)
    return get_range_slider_state, set_range_slider_state


@app.cell
def _(get_range_slider_state):
    if get_range_slider_state() is not None:
        document_range_stack = get_range_slider_state()
    else:
        document_range_stack = None
    return (document_range_stack,)


@app.cell
def _():
    get_chart_state, set_chart_state = mo.state(None)
    return get_chart_state, set_chart_state


@app.cell
def _(get_chart_state, query):
    if query.value is not None:
        chart_visualization = get_chart_state()
    else:
        chart_visualization = None
    return (chart_visualization,)


@app.cell
def c(document_range_stack):
    chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
    return (chart_range_selection,)


if __name__ == "__main__":
    app.run()