|
import marimo

# marimo version this notebook was exported with.
__generated_with = "0.13.0"

# Full-width layout suits the side-by-side forms and charts below.
app = marimo.App(width="full")
|
|
|
with app.setup: |
|
|
|
import marimo as mo |
|
from typing import Dict, Optional, List, Union, Any |
|
from ibm_watsonx_ai import APIClient, Credentials |
|
from pathlib import Path |
|
import pandas as pd |
|
import mimetypes |
|
import requests |
|
import zipfile |
|
import tempfile |
|
import certifi |
|
import base64 |
|
import polars |
|
import nltk |
|
import time |
|
import json |
|
import ast |
|
import os |
|
import io |
|
import re |
|
|
|
def get_iam_token(api_key): |
|
return requests.post( |
|
'https://iam.cloud.ibm.com/identity/token', |
|
headers={'Content-Type': 'application/x-www-form-urlencoded'}, |
|
data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key}, |
|
verify=certifi.where() |
|
).json()['access_token'] |
|
|
|
def setup_task_credentials(client): |
|
|
|
existing_credentials = client.task_credentials.get_details() |
|
|
|
|
|
if "resources" in existing_credentials and existing_credentials["resources"]: |
|
for cred in existing_credentials["resources"]: |
|
cred_id = client.task_credentials.get_id(cred) |
|
client.task_credentials.delete(cred_id) |
|
|
|
|
|
return client.task_credentials.store() |
|
|
|
def get_cred_value(key, creds_var_name="baked_in_creds", default=""): |
|
""" |
|
Helper function to safely get a value from a credentials dictionary. |
|
|
|
Args: |
|
key: The key to look up in the credentials dictionary. |
|
creds_var_name: The variable name of the credentials dictionary. |
|
default: The default value to return if the key is not found. |
|
|
|
Returns: |
|
The value from the credentials dictionary if it exists and contains the key, |
|
otherwise returns the default value. |
|
""" |
|
|
|
if creds_var_name in globals(): |
|
creds_dict = globals()[creds_var_name] |
|
if isinstance(creds_dict, dict) and key in creds_dict: |
|
return creds_dict[key] |
|
return default |
|
|
|
@app.cell
def client_variables(client_instantiation_form):
    # Pull the submitted form values; an unsubmitted/cleared form yields None.
    client_setup = client_instantiation_form.value or None

    if client_setup is None:
        # No credentials yet: blank the env var and expose empty settings.
        os.environ["WATSONX_APIKEY"] = ""
        project_id = None
        space_id = None
        wx_api_key = None
        wx_url = None
    else:
        wx_url = client_setup["wx_region"]
        wx_api_key = client_setup["wx_api_key"].strip()
        # Downstream SDK calls read the key from the environment.
        os.environ["WATSONX_APIKEY"] = wx_api_key

        raw_project = client_setup["project_id"]
        project_id = raw_project.strip() if raw_project is not None else None

        raw_space = client_setup["space_id"]
        space_id = raw_space.strip() if raw_space is not None else None
    return client_setup, project_id, space_id, wx_api_key, wx_url
|
|
|
|
|
@app.cell
def _(client_setup, wx_api_key):
    # Fetch an IAM bearer token only once credentials have been submitted.
    token = get_iam_token(wx_api_key) if client_setup else None
    return
|
|
|
@app.cell
def _():
    # Optional hard-coded credentials; leave empty to force interactive entry.
    baked_in_creds = dict.fromkeys(("purpose", "api_key", "project_id", "space_id"), "")
    return baked_in_creds
|
|
|
|
|
@app.cell
def client_instantiation(
    client_setup,
    project_id,
    space_id,
    wx_api_key,
    wx_url,
):
    # Build watsonx.ai clients from the submitted credentials, refresh task
    # credentials, and expose a status callout kind for the UI.
    if client_setup:
        wx_credentials = Credentials(
            url=wx_url,
            api_key=wx_api_key
        )

        if project_id:
            project_client = APIClient(credentials=wx_credentials, project_id=project_id)
        else:
            project_client = None

        if space_id:
            deployment_client = APIClient(credentials=wx_credentials, space_id=space_id)
        else:
            deployment_client = None

        # BUGFIX: the original unconditionally called
        # setup_task_credentials(deployment_client) whenever no project
        # client existed, which crashed with AttributeError if neither a
        # project_id nor a space_id was supplied. Guard the both-None case.
        if project_client is not None:
            task_credentials_details = setup_task_credentials(project_client)
        elif deployment_client is not None:
            task_credentials_details = setup_task_credentials(deployment_client)
        else:
            task_credentials_details = None

    else:
        wx_credentials = None
        project_client = None
        deployment_client = None
        task_credentials_details = None

    client_status = mo.md("### Client Instantiation Status will turn Green When Ready")

    # Green callout once at least one client is ready.
    if project_client is not None or deployment_client is not None:
        client_callout_kind = "success"
    else:
        client_callout_kind = "neutral"
    return (
        client_callout_kind,
        client_status,
        deployment_client,
        project_client,
    )
|
|
|
|
|
@app.cell
def _():
    # Intro/title cell. BUGFIX: Markdown ATX headings need a space after the
    # '#' characters; the original "#watsonx.ai ..." rendered as literal text.
    mo.md(
        r"""
        # watsonx.ai Embedding Visualizer - Marimo Notebook

        #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages.

        #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy or contents used in building collections by showing how far you are from what you want.
        <br>

        /// admonition
        Created by ***Milan Mrdenovic*** [[email protected]] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025*
        ///


        >Licensed under apache 2.0, users hold full accountability for any use or modification of the code.
        ><br>This asset is part of a set meant to support IBMers, IBM Partners, Clients in developing understanding of how to better utilize various watsonx features and generative AI as a subject matter.

        <br>
        """
    )
    return
|
|
|
|
|
@app.cell
def _():
    # Section heading. BUGFIX: "### Part 1" needs the space after '###' to
    # render as a Markdown heading (the original "###Part 1" did not).
    mo.md("""### Part 1 - Client Setup, File Preparation and Chunking""")
    return
|
|
|
|
|
@app.cell
def accordion_client_setup(client_selector, client_stack):
    # Collapsible wrapper around the client-instantiation UI.
    _client_panel = mo.vstack([client_stack, client_selector], align="center")
    ui_accordion_part_1_1 = mo.accordion({"Instantiate Client": _client_panel})

    ui_accordion_part_1_1
    return
|
|
|
|
|
@app.cell
def accordion_file_upload(select_stack):
    # Collapsible wrapper around model selection and file upload.
    ui_accordion_part_1_2 = mo.accordion({"Select Model & Upload Files": select_stack})

    ui_accordion_part_1_2
    return
|
|
|
|
|
@app.cell
def loaded_texts(
    create_temp_files_from_uploads,
    file_loader,
    pdf_reader,
    run_upload_button,
    set_text_state,
):
    # Once files are staged and the button pressed, spill the uploads to
    # temp files, parse the PDFs, and stash the texts in shared state.
    filepaths = None
    loaded_texts = None
    if file_loader.value is not None and run_upload_button.value:
        filepaths = create_temp_files_from_uploads(file_loader.value)
        loaded_texts = load_pdf_data_with_progress(
            pdf_reader, filepaths, file_loader.value, show_progress=True
        )
        set_text_state(loaded_texts)
    return
|
|
|
|
|
@app.cell
def accordion_chunker_setup(chunker_setup):
    # Collapsible wrapper around the chunking configuration form.
    ui_accordion_part_1_3 = mo.accordion({"Chunker Setup": chunker_setup})

    ui_accordion_part_1_3
    return
|
|
|
|
|
@app.cell
def chunk_documents_to_nodes(
    get_text_state,
    sentence_splitter,
    sentence_splitter_config,
    set_chunk_state,
):
    # Split the loaded texts into nodes once a splitter config is submitted.
    chunked_texts = None
    if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
        chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
        set_chunk_state(chunked_texts)
    return (chunked_texts,)
|
|
|
|
|
@app.cell
def _():
    # Section heading. BUGFIX: "### Part 2" needs the space after '###' to
    # render as a Markdown heading (the original "###Part 2" did not).
    mo.md(r"""### Part 2 - Query Setup and Visualization""")
    return
|
|
|
|
|
@app.cell
def accordion_chunk_range(chart_range_selection):
    # Collapsible wrapper around the chunk-range picker.
    ui_accordion_part_2_1 = mo.accordion({"Chunk Range Selection": chart_range_selection})
    ui_accordion_part_2_1
    return
|
|
|
|
|
@app.cell
def chunk_embedding(
    chunks_to_process,
    embedding,
    sentence_splitter_config,
    set_embedding_state,
):
    # Embed the selected chunk range, with a spinner for user feedback.
    output_embeddings = None
    if sentence_splitter_config.value is not None and chunks_to_process is not None:
        with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
            output_embeddings = embedding.embed_documents(chunks_to_process)
            _spinner.update("Almost Done")
            # Brief pause so the status update is visible to the user.
            time.sleep(1.5)
            set_embedding_state(output_embeddings)
            _spinner.update("Documents Embedded")
    return
|
|
|
|
|
@app.cell
def preview_chunks(chunks_dict):
    # Render the current chunk selection as browsable stat cards.
    ui_chunk_viewer = None
    if chunks_dict is not None:
        stats = create_stats(
            chunks_dict,
            bordered=True,
            object_names=['text', 'text'],
            group_by_row=True,
            items_per_row=5,
            gap=1,
            label="Chunk",
        )
        ui_chunk_viewer = mo.accordion({"View Chunks": stats})

    ui_chunk_viewer
    return
|
|
|
|
|
@app.cell
def accordion_query_view(chart_visualization, query_stack):
    # Query input stacked above the embedding chart, inside an accordion.
    _query_panel = mo.vstack(
        [query_stack, mo.hstack([chart_visualization])], align="center", gap=3
    )
    ui_accordion_part_2_2 = mo.accordion({"Query": _query_panel})
    ui_accordion_part_2_2
    return
|
|
|
|
|
@app.cell
def chunker_setup(sentence_splitter_config):
    # Center the chunker form at a little over half the page width.
    chunker_setup = mo.hstack(
        [sentence_splitter_config], justify="space-around", align="center", widths=[0.55]
    )
    return (chunker_setup,)
|
|
|
|
|
@app.cell
def file_and_model_select(
    file_loader,
    get_embedding_model_list,
    run_upload_button,
):
    # Model table on the left, uploader + run button stacked on the right.
    _upload_column = mo.vstack([file_loader, run_upload_button], align="center")
    select_stack = mo.hstack(
        [get_embedding_model_list(), _upload_column],
        justify="space-around",
        align="center",
        widths=[0.3, 0.3],
    )
    return (select_stack,)
|
|
|
|
|
@app.cell
def client_instantiation_form():
    # Regional watsonx.ai endpoints; the dropdown maps label -> URL.
    wx_platform_url = "https://api.dataplatform.cloud.ibm.com"
    regions = {
        "US": "https://us-south.ml.cloud.ibm.com",
        "EU": "https://eu-de.ml.cloud.ibm.com",
        "GB": "https://eu-gb.ml.cloud.ibm.com",
        "JP": "https://jp-tok.ml.cloud.ibm.com",
        "AU": "https://au-syd.ml.cloud.ibm.com",
        "CA": "https://ca-tor.ml.cloud.ibm.com"
    }

    # Build the individual inputs first, then wire them into the template.
    # Text fields are pre-filled from baked_in_creds when that dict exists.
    _region_input = mo.ui.dropdown(regions, label="Select your watsonx.ai region:", value="US", searchable=True)
    _api_key_input = mo.ui.text(
        placeholder="Add your IBM Cloud api-key...", label="IBM Cloud Api-key:",
        kind="password", value=get_cred_value('api_key', creds_var_name='baked_in_creds'),
    )
    _project_input = mo.ui.text(
        placeholder="Add your watsonx.ai project_id...", label="Project_ID:",
        kind="text", value=get_cred_value('project_id', creds_var_name='baked_in_creds'),
    )
    _space_input = mo.ui.text(
        placeholder="Add your watsonx.ai space_id...", label="Space_ID:",
        kind="text", value=get_cred_value('space_id', creds_var_name='baked_in_creds'),
    )

    client_instantiation_form = (
        mo.md('''
        ###**watsonx.ai credentials:**

        {wx_region}

        {wx_api_key}

        {project_id}

        {space_id}

        > You can add either a project_id, space_id or both, **only one is required**.
        > If you provide both you can switch the active one in the dropdown.
        ''')
        .batch(
            wx_region=_region_input,
            wx_api_key=_api_key_input,
            project_id=_project_input,
            space_id=_space_input,
        )
        .form(show_clear_button=True, bordered=False)
    )
    return (client_instantiation_form,)
|
|
|
|
|
@app.cell
def instantiation_status(
    client_callout_kind,
    client_instantiation_form,
    client_status,
):
    # Form on the left, status callout (green once a client exists) on the right.
    client_callout = mo.callout(client_status, kind=client_callout_kind)
    client_stack = mo.hstack(
        [client_instantiation_form, client_callout],
        align="center", justify="space-around", gap=10,
    )
    return (client_stack,)
|
|
|
|
|
@app.cell
def client_selector(deployment_client, project_client):
    # Offer every instantiated client in a dropdown.
    # BUGFIX: the original tested the "both clients" case *after* the
    # single-client cases, making that branch unreachable -- the project
    # client was hidden whenever a deployment client existed. Check the
    # combined case first.
    if project_client is not None and deployment_client is not None:
        client_options = {"Project Client": project_client, "Deployment Client": deployment_client}

    elif deployment_client is not None:
        client_options = {"Deployment Client": deployment_client}

    elif project_client is not None:
        client_options = {"Project Client": project_client}

    else:
        # Placeholder entry until a client is instantiated.
        client_options = {"No Client": "Instantiate a Client"}

    default_client = next(iter(client_options))
    client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Select your active client:**")

    return (client_selector,)
|
|
|
|
|
@app.cell
def active_client(client_selector):
    # The dropdown maps labels to client objects; the no-client placeholder
    # maps to a sentinel string, which we normalize to None here.
    client_key = client_selector.value
    client = None if client_key == "Instantiate a Client" else client_key
    return (client,)
|
|
|
|
|
@app.cell
def emb_model_selection(client, set_embedding_model_list):
    # Build the embedding-model picker table. Known models are annotated
    # with their token limits and vector sizes; unknown ones get zeros.
    if client is not None:
        model_specs = client.foundation_models.get_embeddings_model_specs()

        resources = model_specs["resources"]

        # Static metadata for the embedding models we know about.
        embedding_models = {
            "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
        }

        model_id_list = [resource["model_id"] for resource in resources]

        embedding_model_data = []
        for model_id in model_id_list:
            known = embedding_models.get(model_id)
            embedding_model_data.append({
                "model_id": model_id,
                # Unknown models get 0 so they still render in the table.
                "max_tokens": known["max_tokens"] if known else 0,
                "embedding_dimensions": known["embedding_dimensions"] if known else 0,
            })

        # CONSISTENCY: reuse the shared table builder (as the no-client
        # fallback below already did) instead of duplicating mo.ui.table.
        embedding_model_selection = create_emb_model_selection_table(
            embedding_model_data,
            initial_selection=1,
            selection_type="single",
            label="Select an embedding model to use.",
        )
        set_embedding_model_list(embedding_model_selection)
    else:
        # Offline fallback: a single sensible default model.
        default_model_data = [{
            "model_id": "ibm/granite-embedding-107m-multilingual",
            "max_tokens": 512,
            "embedding_dimensions": 384
        }]

        set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
    return
|
|
|
|
|
@app.function
def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
    """Build a marimo table over embedding-model rows.

    Args:
        model_data: List of dicts describing models (one row each).
        initial_selection: Row index selected when the table first renders.
        selection_type: marimo table selection mode (e.g. "single").
        label: Caption shown above the table.

    Returns:
        A configured mo.ui.table.
    """
    return mo.ui.table(
        model_data,
        selection=selection_type,
        label=label,
        page_size=30,
        initial_selection=[initial_selection],
    )
|
|
|
|
|
@app.cell
def embedding_model():
    # Shared state holding the currently built model-selection table
    # (set by emb_model_selection, read by emb_model_state and the layout).
    get_embedding_model_list, set_embedding_model_list = mo.state(None)
    return get_embedding_model_list, set_embedding_model_list
|
|
|
|
|
@app.cell
def emb_model_parameters(emb_model_max_tk):
    # Embedding call parameters: truncate inputs to the model's token limit
    # and echo the input text back in the response.
    from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams

    # BUGFIX: the original tested `embedding_model is not None`, but
    # `embedding_model` is not an input of this cell -- at module scope it
    # resolved to the (always truthy) cell object of the same name, so the
    # fallback branch never ran and TRUNCATE_INPUT_TOKENS could be set to
    # None. Test the actual token limit instead.
    if emb_model_max_tk is not None:
        truncate_tokens = emb_model_max_tk
    else:
        truncate_tokens = 128  # conservative default when no model is selected

    embed_params = {
        EmbedParams.TRUNCATE_INPUT_TOKENS: truncate_tokens,
        EmbedParams.RETURN_OPTIONS: {
            'input_text': True
        }
    }
    return embed_params
|
|
|
|
|
@app.cell
def emb_model_state(get_embedding_model_list):
    # Materialize the model-selection table from shared state so downstream
    # cells re-run when the selection changes.
    embedding_model = get_embedding_model_list()
    return (embedding_model,)
|
|
|
|
|
@app.cell
def emb_model_setup(embedding_model):
    # Unpack the selected table row into id / token limit / vector size.
    if embedding_model is not None:
        _selected_row = embedding_model.value[0]
        emb_model = _selected_row['model_id']
        emb_model_max_tk = _selected_row['max_tokens']
        emb_model_emb_dim = _selected_row['embedding_dimensions']
    else:
        # Nothing selected yet.
        emb_model = emb_model_max_tk = emb_model_emb_dim = None
    return emb_model, emb_model_emb_dim, emb_model_max_tk
|
|
|
|
|
@app.cell
def emb_model_instantiation(client, emb_model, embed_params):
    # Create the Embeddings interface against the active client.
    from ibm_watsonx_ai.foundation_models import Embeddings

    embedding = None
    if client is not None:
        embedding = Embeddings(
            model_id=emb_model,
            api_client=client,
            params=embed_params,
            batch_size=1000,       # documents per request
            concurrency_limit=10,  # parallel requests
        )
    return (embedding,)
|
|
|
|
|
@app.cell
def _():
    # Shared state holding the embeddings of the currently selected chunks.
    get_embedding_state, set_embedding_state = mo.state(None)
    return get_embedding_state, set_embedding_state
|
|
|
|
|
@app.cell
def _():
    # Shared state holding the embedded user query.
    get_query_state, set_query_state = mo.state(None)
    return get_query_state, set_query_state
|
|
|
|
|
@app.cell
def file_loader_input():
    # Drag-and-drop upload area restricted to PDFs; multiple files allowed.
    _pdf_only = [".pdf"]
    file_loader = mo.ui.file(kind="area", filetypes=_pdf_only, label=" Load .pdf files ", multiple=True)
    return (file_loader,)
|
|
|
|
|
@app.cell
def file_loader_run(file_loader):
    # Enable the run button only after at least one file has been staged.
    run_upload_button = mo.ui.run_button(
        disabled=not file_loader.value,
        label="Load Files",
    )
    return (run_upload_button,)
|
|
|
|
|
@app.cell
def helper_function_tempfiles():
    def create_temp_files_from_uploads(upload_results) -> List[str]:
        """Write each uploaded file into the system temp directory.

        Args:
            upload_results: Sequence of FileUploadResults objects, each
                exposing `.name` and `.contents`.

        Returns:
            List of paths of the temporary files, one per upload.
        """
        temp_file_paths = []
        temp_dir = tempfile.gettempdir()

        # Iterate the results directly instead of indexing with range(len()).
        for result in upload_results:
            # NOTE(review): files keep their upload name, so two uploads with
            # the same name overwrite each other in the temp dir.
            temp_path = os.path.join(temp_dir, result.name)
            with open(temp_path, 'wb') as temp_file:
                temp_file.write(result.contents)
            temp_file_paths.append(temp_path)

        return temp_file_paths

    def cleanup_temp_files(temp_file_paths: List[str]) -> None:
        """Delete temporary files after use."""
        for path in temp_file_paths:
            if os.path.exists(path):
                os.unlink(path)
    # BUGFIX: also export cleanup_temp_files -- the original returned only
    # the creator, leaving the cleanup helper unreachable from other cells.
    return (create_temp_files_from_uploads, cleanup_temp_files)
|
|
|
|
|
@app.function
def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
    """Load each PDF and key the results by the original upload filename.

    Args:
        pdf_reader: The PyMuPDFReader instance.
        filepaths: List of temporary file paths, parallel to the uploads.
        file_loader_value: Original upload results (provides `.name`).
        show_progress: Whether to show a marimo progress bar while loading.

    Returns:
        Dict mapping original filenames to their loaded text content.
    """
    results = {}

    if show_progress:
        import marimo as mo

        with mo.status.progress_bar(
            total=len(filepaths),
            title="Loading PDFs",
            subtitle="Processing documents...",
            completion_title="PDF Loading Complete",
            completion_subtitle=f"{len(filepaths)} documents processed",
            remove_on_exit=True,
        ) as bar:
            for idx, file_path in enumerate(filepaths):
                # The temp paths are parallel to the upload results.
                name = file_loader_value[idx].name
                bar.update(subtitle=f"Processing {name}...")
                results[name] = pdf_reader.load_data(file_path=file_path, metadata=True)
                bar.update(increment=1)
    else:
        for idx, file_path in enumerate(filepaths):
            name = file_loader_value[idx].name
            results[name] = pdf_reader.load_data(file_path=file_path, metadata=True)

    return results
|
|
|
|
|
@app.cell
def file_readers():
    # PDF reader plus the splitter class used by the chunker cells.
    # NOTE: the original also imported FlatReader, which was never used.
    from llama_index.readers.file import PyMuPDFReader
    from llama_index.core.node_parser import SentenceSplitter

    pdf_reader = PyMuPDFReader()

    return SentenceSplitter, pdf_reader
|
|
|
|
|
@app.cell
def sentence_splitter_setup():
    # Chunking configuration form: size/overlap sliders plus the splitter's
    # separator and regex knobs. Inputs are built first, then wired into the
    # markdown template below.
    _chunk_size_slider = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=350, show_value=True, full_width=True)
    _chunk_overlap_slider = mo.ui.slider(start=1, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=50, show_value=True, full_width=True)
    _separator_input = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" ")
    _paragraph_separator_input = mo.ui.text(
        placeholder="Define a paragraph separator",
        label="**Paragraph Separator:**", kind="text",
        value="\n\n\n",
    )
    _chunking_regex_input = mo.ui.text(
        placeholder="Define a secondary chunking regex",
        label="**Chunking Regex:**", kind="text",
        value="[^,.;?!]+[,.;?!]?",
    )
    _metadata_checkbox = mo.ui.checkbox(value=True, label="**Include Metadata**")

    sentence_splitter_config = (
        mo.md('''
        ###**Chunking Setup:**

        > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might result in unexpected results.

        Separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.

        {chunk_size}

        {chunk_overlap}

        {separator} {paragraph_separator}

        {secondary_chunking_regex} {include_metadata}

        ''')
        .batch(
            chunk_size=_chunk_size_slider,
            chunk_overlap=_chunk_overlap_slider,
            separator=_separator_input,
            paragraph_separator=_paragraph_separator_input,
            secondary_chunking_regex=_chunking_regex_input,
            include_metadata=_metadata_checkbox,
        )
        .form(show_clear_button=True, bordered=False)
    )
    return (sentence_splitter_config,)
|
|
|
|
|
@app.cell
def sentence_splitter_instantiation(
    SentenceSplitter,
    sentence_splitter_config,
):
    def simple_whitespace_tokenizer(text):
        # Token = whitespace-delimited word; keeps chunk sizing predictable.
        return text.split()

    if sentence_splitter_config.value is not None:
        _cfg = sentence_splitter_config.value
        # Cap the overlap at 30% of the chunk size so it can never dominate.
        _overlap_cap = int(_cfg.get("chunk_size") * 0.3)
        validated_chunk_overlap = min(_cfg.get("chunk_overlap"), _overlap_cap)

        sentence_splitter = SentenceSplitter(
            chunk_size=_cfg.get("chunk_size"),
            chunk_overlap=validated_chunk_overlap,
            separator=_cfg.get("separator"),
            paragraph_separator=_cfg.get("paragraph_separator"),
            secondary_chunking_regex=_cfg.get("secondary_chunking_regex"),
            include_metadata=_cfg.get("include_metadata"),
            tokenizer=simple_whitespace_tokenizer,
        )
    else:
        # Defaults used until the form is submitted.
        sentence_splitter = SentenceSplitter(
            chunk_size=2048,
            chunk_overlap=204,
            separator=" ",
            paragraph_separator="\n\n\n",
            secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
            include_metadata=True,
            tokenizer=simple_whitespace_tokenizer,
        )
    return (sentence_splitter,)
|
|
|
|
|
@app.cell
def text_state():
    # Shared state: filename -> loaded PDF texts (set after upload).
    get_text_state, set_text_state = mo.state(None)
    return get_text_state, set_text_state
|
|
|
|
|
@app.cell
def chunk_state():
    # Shared state: filename -> chunked nodes (set after chunking).
    get_chunk_state, set_chunk_state = mo.state(None)
    return get_chunk_state, set_chunk_state
|
|
|
|
|
@app.function
def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
    """Split every loaded document into chunks, file by file.

    Args:
        loaded_texts (dict): Maps filename -> list of Document objects.
        sentence_splitter: Object exposing get_nodes_from_documents().
        show_progress (bool): Show a marimo progress bar while chunking.

    Returns:
        dict: Same keys as `loaded_texts`, values are the chunked nodes.
    """
    chunked_texts_dict = {}

    total_docs = sum(len(docs) for docs in loaded_texts.values())
    processed_docs = 0

    if show_progress:
        import marimo as mo

        with mo.status.progress_bar(
            total=total_docs,
            title="Processing Documents",
            subtitle="Chunking documents...",
            completion_title="Processing Complete",
            completion_subtitle=f"{total_docs} documents processed",
            remove_on_exit=True,
        ) as bar:
            for key, documents in loaded_texts.items():
                doc_count = len(documents)
                bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")

                # The splitter's own progress output is suppressed here; the
                # marimo bar tracks progress at document granularity instead.
                chunked_texts_dict[key] = sentence_splitter.get_nodes_from_documents(
                    documents,
                    show_progress=False,
                )
                # Brief pause so the per-file subtitle stays readable.
                time.sleep(0.15)

                bar.update(increment=doc_count)
                processed_docs += doc_count
    else:
        for key, documents in loaded_texts.items():
            # No marimo bar: let the splitter report its own progress.
            chunked_texts_dict[key] = sentence_splitter.get_nodes_from_documents(
                documents,
                show_progress=True,
            )

    return chunked_texts_dict
|
|
|
|
|
@app.cell
def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
    # Surface the chunked nodes from shared state once chunking has happened.
    chunked_documents = (
        get_chunk_state()
        if chunked_texts is not None and sentence_splitter
        else None
    )
    return (chunked_documents,)
|
|
|
|
|
@app.cell
def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
    # Convert the chunked node objects into plain dicts for DataFrame use.
    # NOTE: the original also converted the dicts straight back into nodes
    # (`nodes_from_dict`) without ever using the result; that dead
    # round-trip has been dropped.
    if chunked_documents is not None:
        dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
    else:
        dict_from_nodes = None
    return (dict_from_nodes,)
|
|
|
|
|
@app.cell
def chunks_to_process(
    dict_from_nodes,
    document_range_stack,
    get_data_in_range_triplequote,
):
    # Flatten the per-file chunk dicts into one DataFrame, slice it to the
    # user-selected range, and pull out the chunk texts to embed.
    if dict_from_nodes is not None and document_range_stack is not None:

        chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)

        if document_range_stack.value is not None:
            chunk_start_idx = document_range_stack.value[0]
            chunk_end_idx = document_range_stack.value[1]
        else:
            # No explicit range selected: use every chunk.
            chunk_start_idx = 0
            chunk_end_idx = len(chunk_dict_df)

        chunk_range_index = [chunk_start_idx, chunk_end_idx]
        chunks_dict = get_data_in_range_triplequote(chunk_dict_df,
                                                    index_range=chunk_range_index,
                                                    columns_to_include=["text"])

        chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
    else:
        # (Dropped the dead `chunk_objects = None` placeholder that the
        # original assigned but never used or returned.)
        chunks_dict = None
        chunks_to_process = None
    return chunks_dict, chunks_to_process
|
|
|
|
|
@app.cell
def helper_function_doc_formatting():
    def llamaindex_convert_docs_multi(items):
        """
        Automatically convert between document objects and dictionaries.

        This function handles:
        - Converting dictionaries to document objects
        - Converting document objects to dictionaries
        - Processing lists or individual items
        - Supporting dictionary structures where values are lists of documents

        Args:
            items: A document object, dictionary, or list of either.
                Can also be a dictionary mapping filenames to lists of documents.

        Returns:
            Converted item(s) maintaining the original structure
        """

        # Empty input (None, [], {}) normalizes to an empty list.
        if not items:
            return []

        # Case 1: {filename: [docs...]} mapping -- convert each file's list
        # and keep the dict shape.
        if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
            result = {}
            for filename, doc_list in items.items():
                result[filename] = llamaindex_convert_docs(doc_list)
            return result

        # Case 2: a single (non-list) item.
        if not isinstance(items, list):

            if isinstance(items, dict):
                # A lone dict is treated as a serialized document; 'doc_type'
                # (when present) names the concrete class to rebuild.
                doc_class = None
                if 'doc_type' in items:
                    import importlib
                    module_path, class_name = items['doc_type'].rsplit('.', 1)
                    module = importlib.import_module(module_path)
                    doc_class = getattr(module, class_name)
                if not doc_class:
                    # Fall back to the generic llama-index Document.
                    from llama_index.core.schema import Document
                    doc_class = Document
                return doc_class.from_dict(items)

            elif hasattr(items, 'to_dict'):
                # A lone document object -> its dict form.
                return items.to_dict()

            # Anything else passes through unchanged.
            return items

        # Case 3: a flat list. Direction is decided from the first non-None
        # element: dicts -> documents, documents -> dicts.
        result = []

        if len(items) == 0:
            return result

        first_item = next((item for item in items if item is not None), None)

        # All-None list: nothing to convert.
        if first_item is None:
            return result

        if isinstance(first_item, dict):

            # Resolve the target class once from the first dict; all dicts
            # in the list are assumed to share it. TODO confirm mixed-type
            # lists are not expected here.
            doc_class = None

            if 'doc_type' in first_item:
                import importlib
                module_path, class_name = first_item['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)
            if not doc_class:
                from llama_index.core.schema import Document
                doc_class = Document

            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
                elif item is None:
                    # Preserve positional Nones.
                    result.append(None)
                elif isinstance(item, list):
                    # Nested list: convert one level down.
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        else:
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        return result

    def llamaindex_convert_docs(items):
        """
        Automatically convert between document objects and dictionaries.

        Args:
            items: A list of document objects or dictionaries

        Returns:
            List of converted items (dictionaries or document objects)
        """
        result = []

        # Empty/None input -> empty list.
        if not items:
            return result

        # Direction is decided from the first element only.
        if isinstance(items[0], dict):

            doc_class = None

            # 'doc_type' (when present) names the concrete class to rebuild.
            if 'doc_type' in items[0]:
                import importlib
                module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)

            if not doc_class:
                # Fall back to the generic llama-index Document.
                from llama_index.core.schema import Document
                doc_class = Document

            # NOTE(review): non-dict entries in a dict-led list are silently
            # dropped here (unlike the *_multi variant, which passes them
            # through) -- confirm this asymmetry is intended.
            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
        else:

            # NOTE(review): entries without to_dict() are likewise dropped.
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())

        return result
    return (llamaindex_convert_docs_multi,)
|
|
|
|
|
@app.cell
def helper_function_create_df():
    def create_document_dataframes(dict_from_docs):
        """Build one pandas DataFrame per source file.

        Args:
            dict_from_docs: Dictionary mapping filenames to lists of documents.

        Returns:
            List of DataFrames, one per file, each row a document tagged with
            its `doc_index` and `filename`.
        """
        dataframes = []
        for filename, docs in dict_from_docs.items():
            file_records = []
            for i, doc in enumerate(docs):
                # Normalize each document into a plain dict.
                if hasattr(doc, 'to_dict'):
                    doc_data = doc.to_dict()
                elif isinstance(doc, dict):
                    doc_data = doc
                else:
                    doc_data = {'content': str(doc)}

                doc_data['doc_index'] = i
                file_records.append(doc_data)

            # Files that produced no records get no DataFrame.
            if file_records:
                df = pd.DataFrame(file_records)
                df['filename'] = filename
                dataframes.append(df)

        return dataframes

    def create_dataframe_previews(dataframe_list, page_size=5):
        """Wrap each DataFrame in a paged mo.ui.dataframe preview.

        Args:
            dataframe_list: List of pandas DataFrames.
            page_size: Rows shown per page in each preview.

        Returns:
            List of mo.ui.dataframe components.
        """
        return [mo.ui.dataframe(df, page_size=page_size) for df in dataframe_list]
    return
|
|
|
|
|
@app.cell |
|
def helper_function_chart_preparation():
    # Visualization-only dependencies are imported inside this cell so the
    # rest of the notebook does not pay their import cost.
    import altair as alt
    import numpy as np
    import plotly.express as px
    from sklearn.manifold import TSNE

    def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
        """
        Prepare embedding data for 2D visualization (t-SNE reduction).

        Args:
            embeddings: List of embeddings arrays
            texts: List of text strings
            model_id: Embedding model ID (optional)
            embedding_dimensions: Embedding dimensions (optional)

        Returns:
            (DataFrame with x/y coords and text columns, metadata dict)
        """
        # Some providers return a single embedding wrapped in an extra list
        # ([[...]]); unwrap those so every entry is a flat vector.
        flattened_embeddings = []
        for emb in embeddings:
            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
                flattened_embeddings.append(emb[0])
            else:
                flattened_embeddings.append(emb)

        embedding_array = np.array(flattened_embeddings)

        # t-SNE requires perplexity < n_samples, hence the cap.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1))
        reduced_embeddings = tsne.fit_transform(embedding_array)

        # Short labels for tooltips; the full text is kept in a separate column.
        truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]

        df = pd.DataFrame({
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "text": truncated_texts,
            "full_text": texts,
            "index": range(len(texts))
        })

        metadata = {
            "model_id": model_id,
            "embedding_dimensions": embedding_dimensions
        }

        return df, metadata

    def create_embedding_chart(df, metadata=None):
        """
        Create an Altair chart for 2D embedding visualization.

        Args:
            df: DataFrame with x, y coordinates and text
            metadata: Dictionary with model_id and embedding_dimensions

        Returns:
            Altair chart (points + index labels, pan/zoom enabled)
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        # NOTE(review): selection_multi/add_selection are the Altair 4 API,
        # deprecated in Altair 5 (selection_point/add_params) — confirm the
        # pinned altair version before upgrading.
        selection = alt.selection_multi(fields=['index'])

        base = alt.Chart(df).encode(
            x=alt.X("x:Q", title="Dimension 1"),
            y=alt.Y("y:Q", title="Dimension 2"),
            tooltip=["text", "index"]
        )

        # Selected points stay opaque; unselected ones are dimmed.
        points = base.mark_circle(size=100).encode(
            color=alt.Color("index:N", legend=None),
            opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
        ).add_selection(selection)

        # Numeric index labels drawn next to each point.
        text = base.mark_text(align="left", dx=7).encode(
            text="index:N"
        )

        return (points + text).properties(
            width=700,
            height=500,
            title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
        ).interactive()

    def show_selected_text(indices, texts):
        """
        Create a markdown display for the texts at the selected indices.

        Args:
            indices: List of selected indices
            texts: List of all texts

        Returns:
            Markdown string (out-of-range indices are silently dropped)
        """
        if not indices:
            return "No text selected"

        selected_texts = [texts[i] for i in indices if i < len(texts)]
        return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])

    def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
        """
        Prepare embedding data for 3D visualization (t-SNE to 3 components).

        Args:
            embeddings: List of embeddings arrays
            texts: List of text strings
            model_id: Embedding model ID (optional)
            embedding_dimensions: Embedding dimensions (optional)

        Returns:
            (DataFrame with x/y/z coords, hover text and raw embeddings,
             metadata dict)
        """
        # Unwrap single-vector [[...]] responses into flat vectors.
        flattened_embeddings = []
        for emb in embeddings:
            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
                flattened_embeddings.append(emb[0])
            else:
                flattened_embeddings.append(emb)

        embedding_array = np.array(flattened_embeddings)

        # t-SNE cannot run on a single sample; pin it at the origin instead.
        if len(embedding_array) == 1:
            reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
        else:
            # Perplexity must stay >= 1 and below n_samples.
            perplexity_value = max(1.0, min(30, len(embedding_array)-1))
            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
            reduced_embeddings = tsne.fit_transform(embedding_array)

        # Build bold, <br>-wrapped hover labels capped at 500 characters.
        formatted_texts = []
        for text in texts:
            if len(text) > 500:
                text = text[:500] + "..."

            # Hard-wrap every 50 characters for the Plotly hover box.
            wrapped_text = ""
            for i in range(0, len(text), 50):
                wrapped_text += text[i:i+50] + "<br>"

            formatted_texts.append("<b>"+wrapped_text+"</b>")

        # Raw embeddings ride along in the frame so a query point can later
        # be positioned by cosine similarity (see get_query_coordinates).
        df = pd.DataFrame({
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "z": reduced_embeddings[:, 2],
            "text": formatted_texts,
            "full_text": texts,
            "index": range(len(texts)),
            "embedding": flattened_embeddings
        })

        metadata = {
            "model_id": model_id,
            "embedding_dimensions": embedding_dimensions
        }

        return df, metadata

    def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int=3):
        """
        Create a 3D Plotly chart for embedding visualization with
        proximity-based coloring (points far from everything are brighter
        on the reversed Viridis scale).
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        from scipy.spatial.distance import pdist, squareform

        coords = df[['x', 'y', 'z']].values

        # Full pairwise distance matrix of the reduced coordinates.
        dist_matrix = squareform(pdist(coords))

        # Each point's mean distance to all points (including itself, which
        # contributes 0) — used as the color dimension.
        avg_distances = np.mean(dist_matrix, axis=1)

        df['proximity'] = avg_distances

        fig = px.scatter_3d(
            df,
            x='x',
            y='y',
            z='z',
            color='proximity',
            color_continuous_scale='Viridis_r',
            hover_data=['text', 'index', 'proximity'],
            labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
            title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
            text='index',
        )

        # Diamond markers with faint index labels; custom hover template
        # pulls wrapped text and proximity from customdata.
        fig.update_traces(
            marker=dict(
                size=marker_size_var,
                opacity=0.7,
                symbol="diamond",
                line=dict(
                    width=0.5,
                    color="white"
                )
            ),
            textfont=dict(
                color="rgba(255, 255, 255, 0.3)",
                size=8
            ),
            hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
            hoverinfo="text+name",
            hoverlabel=dict(
                bgcolor="white",
                font_size=12
            ),
            selector=dict(type='scatter3d')
        )

        # Dark theme: black canvas, dim axis planes, white grid lines.
        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title='Dimension 1',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                yaxis=dict(
                    title='Dimension 2',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                zaxis=dict(
                    title='Dimension 3',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.25, y=1.25, z=1.25),
                ),
                aspectratio=dict(x=1, y=1, z=1),
                aspectmode='data'
            ),
            width=int(chart_width),
            height=int(chart_height),
            margin=dict(r=20, l=10, b=10, t=50),
            paper_bgcolor="rgb(0, 0, 0)",
            plot_bgcolor="rgb(0, 0, 0)",
            coloraxis_colorbar=dict(
                title="Average Distance",
                thicknessmode="pixels", thickness=20,
                lenmode="pixels", len=400,
                yanchor="top", y=1,
                ticks="outside",
                dtick=0.1
            )
        )

        return fig
    return create_3d_embedding_chart, prepare_embedding_data_3d
|
|
|
|
|
@app.cell |
|
def helper_function_text_preparation():
    def convert_table_to_json_docs(df, selected_columns=None):
        """
        Convert a pandas DataFrame or dictionary to a list of JSON documents.
        Dynamically includes columns based on user selection.
        Column names are standardized to lowercase with underscores instead of spaces
        and special characters removed.

        Args:
            df: The DataFrame or dictionary to process
            selected_columns: List of column names to include in the output documents

        Returns:
            list: A list of dictionaries, each representing a row as a JSON document
        """
        import pandas as pd
        import re

        def standardize_key(key):
            """Convert a column name to lowercase with underscores instead of spaces and no special characters."""
            if not isinstance(key, str):
                return str(key).lower()
            key = key.lower().replace(' ', '_')
            return re.sub(r'[^\w]', '', key)

        # A plain dict is treated as a single-row table.
        if isinstance(df, dict):
            if selected_columns:
                return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
            else:
                return [{standardize_key(k): v for k, v in df.items()}]

        if df is None:
            return []

        # Coerce list-like inputs to a DataFrame; bail out on anything
        # pandas cannot interpret (was a bare `except:` — narrowed so
        # KeyboardInterrupt/SystemExit are no longer swallowed).
        if not isinstance(df, pd.DataFrame):
            try:
                df = pd.DataFrame(df)
            except Exception:
                return []

        if df.empty:
            return []

        # No explicit (list) selection means "all columns".
        if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
            selected_columns = list(df.columns)

        # Resolve requested names case-insensitively against real columns.
        available_columns = []
        columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}

        for col in selected_columns:
            if col in df.columns:
                available_columns.append(col)
            elif isinstance(col, str) and col.lower() in columns_lower:
                available_columns.append(columns_lower[col.lower()])

        if not available_columns:
            return []

        # Emit one dict per row with standardized keys; NaN becomes None.
        json_docs = []
        for _, row in df.iterrows():
            doc = {}
            for col in available_columns:
                value = row[col]
                std_col = standardize_key(col)
                doc[std_col] = None if pd.isna(value) else value
            json_docs.append(doc)

        return json_docs

    def get_column_values(df, columns_to_include):
        """
        Extract values from specified columns of a dataframe as lists.

        Args:
            df: A pandas DataFrame
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their values as lists
            (missing columns are reported via a printed warning and skipped)
        """
        result = {}

        valid_columns = [col for col in columns_to_include if col in df.columns]
        invalid_columns = set(columns_to_include) - set(valid_columns)

        if invalid_columns:
            print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")

        for col in valid_columns:
            result[col] = df[col].tolist()

        return result

    def get_data_in_range(doc_dict_df, index_range, columns_to_include):
        """
        Extract values from specified columns of a dataframe within a given index range.

        Args:
            doc_dict_df: The pandas DataFrame to extract data from
            index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their values (within the index range) as lists
        """
        max_index = len(doc_dict_df)
        if index_range <= 0:
            print(f"Warning: Invalid index range {index_range}. Must be positive.")
            return {}

        # Clamp the window to the dataframe's actual length.
        if index_range > max_index:
            print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
            index_range = max_index

        df_subset = doc_dict_df.iloc[:index_range]

        return get_column_values(df_subset, columns_to_include)

    def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
        """
        Extract values from specified columns of a dataframe within a given index range,
        escaping URL schemes in string values ("http://" -> "http\\://",
        "https://" -> "https\\://") so links render literally downstream.

        Args:
            doc_dict_df: The pandas DataFrame to extract data from
            index_range: A list of two integers specifying the start and end indices of rows to include
                (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their (escaped) values as lists
        """
        start_idx, end_idx = index_range
        max_index = len(doc_dict_df)

        # Clamp the requested window to a sane, non-empty range.
        if start_idx < 0:
            print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
            start_idx = 0

        if end_idx <= start_idx:
            print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
            end_idx = start_idx + 1

        if end_idx > max_index:
            print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
            end_idx = max_index

        df_subset = doc_dict_df.iloc[start_idx:end_idx]

        result = get_column_values(df_subset, columns_to_include)

        # Escape URL schemes in string values; non-strings pass through as-is.
        for col in result:
            if isinstance(result[col], list):
                processed_items = []
                for item in result[col]:
                    if isinstance(item, str):
                        item = item.replace("http://", "http\\://").replace("https://", "https\\://")
                        processed_items.append(item)
                    else:
                        processed_items.append(item)
                result[col] = processed_items
        return result
    return (get_data_in_range_triplequote,)
|
|
|
|
|
@app.cell |
|
def prepare_doc_select(sentence_splitter_config):
    def prepare_document_selection(node_dict):
        """
        Creates document selection UI component.

        Args:
            node_dict: Dictionary mapping filenames to lists of documents

        Returns:
            mo.ui component (range-slider form) for selecting a chunk range
        """
        # Total chunk count across all files sets the slider's upper bound.
        # (The previous version also flattened every document into a records
        # list that was never read — that dead work has been removed.)
        total_docs = sum(len(docs) for docs in node_dict.values())

        # The slider needs stop >= start even when node_dict is empty.
        stop_value = max(total_docs, 1)
        llama_docs = mo.ui.range_slider(
            start=1,
            stop=stop_value,
            step=1,
            full_width=True,
            show_value=True,
            label="**Select a Range of Chunks to Visualize:**"
        ).form(submit_button_disabled=check_state(sentence_splitter_config.value))

        return llama_docs
    return (prepare_document_selection,)
|
|
|
|
|
@app.cell |
|
def document_range_selection(
    dict_from_nodes,
    prepare_document_selection,
    set_range_slider_state,
):
    # With real nodes: build the slider and publish it to shared state.
    # Without: build a placeholder from an empty mapping, leaving state alone.
    if dict_from_nodes is None:
        llama_docs = prepare_document_selection({})
    else:
        llama_docs = prepare_document_selection(dict_from_nodes)
        set_range_slider_state(llama_docs)
    return
|
|
|
|
|
@app.function |
|
def create_cumulative_dataframe(dict_from_docs):
    """
    Flatten a {filename: [documents]} mapping into one cumulative DataFrame.

    Args:
        dict_from_docs: Dictionary mapping filenames to lists of documents

    Returns:
        DataFrame with all documents flattened, carrying per-file and
        global (1-based) indices
    """
    records = []
    running_index = 1

    for fname, doc_list in dict_from_docs.items():
        for local_idx, document in enumerate(doc_list):
            # Normalize each document into a plain dict (copy dicts so the
            # caller's objects are not mutated).
            if hasattr(document, 'to_dict'):
                row = document.to_dict()
            elif isinstance(document, dict):
                row = document.copy()
            else:
                row = {'content': str(document)}

            row['filename'] = fname
            row['doc_index'] = local_idx
            row['global_index'] = running_index

            # Mirror 'content' into 'text' so downstream code can rely on it.
            if 'content' in row and 'text' not in row:
                row['text'] = row['content']

            records.append(row)
            running_index += 1

    return pd.DataFrame(records)
|
|
|
|
|
@app.function |
|
def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
    """
    Build mo.stat cards for every entry of a text dictionary.

    Parameters:
    - texts_dict (dict): Dictionary containing the text data
    - bordered (bool): Whether the stats should be bordered
    - object_names (list or tuple): Two keys of texts_dict, used as
      [label_object, value_object]
    - group_by_row (bool): Whether to group stats in horizontal rows
    - items_per_row (int): Stats per row when group_by_row is True
    - gap (int): Gap passed to mo.hstack between stats in a row
    - label (str): Prefix shown before each stat's count

    Returns:
    - object: A vertical stack of stats, or of rows of stats

    Raises:
    - ValueError: if object_names is missing/short or a key is absent
    """
    if not object_names or len(object_names) < 2:
        raise ValueError("You must provide two object names as a list or tuple")

    label_object, value_object = object_names[0], object_names[1]

    # Both keys must exist before any UI is built.
    if label_object not in texts_dict:
        raise ValueError(f"Label object '{label_object}' not found in texts_dict")
    if value_object not in texts_dict:
        raise ValueError(f"Value object '{value_object}' not found in texts_dict")

    num_items = len(texts_dict[label_object])

    # One stat card per entry: label text plus the length of its value.
    individual_stats = [
        mo.stat(
            label=texts_dict[label_object][i],
            value=f"{label} Number: {len(texts_dict[value_object][i])}",
            bordered=bordered,
        )
        for i in range(num_items)
    ]

    if not group_by_row:
        return mo.vstack(individual_stats, wrap=False)

    # Chunk the cards into rows of items_per_row.
    rows = []
    for start in range(0, num_items, items_per_row):
        row_stats = individual_stats[start:start + items_per_row]
        rows.append(
            mo.hstack(
                row_stats,
                gap=gap,
                align="start",
                justify="center",
                widths=[0.35] * len(row_stats),
            )
        )

    return mo.vstack(rows)
|
|
|
|
|
@app.cell |
|
def prepare_chart_embeddings(
    chunks_to_process,
    emb_model,
    emb_model_emb_dim,
    get_embedding_state,
    prepare_embedding_data_3d,
):
    # Only reduce to 3D once both the chunks and their stored embeddings exist.
    ready = chunks_to_process is not None and get_embedding_state() is not None

    if ready:
        chart_dataframe, chart_metadata = prepare_embedding_data_3d(
            get_embedding_state(),
            chunks_to_process,
            model_id=emb_model,
            embedding_dimensions=emb_model_emb_dim,
        )
    else:
        chart_dataframe, chart_metadata = None, None
    return chart_dataframe, chart_metadata
|
|
|
|
|
@app.cell |
|
def chart_dims():
    # Sliders controlling the rendered size of the 3D embedding chart.
    # NOTE(review): the second batch key is spelled 'chat_width' (sic);
    # the downstream cell reads chart_dimensions.value['chat_width'], so
    # the key must not be renamed in isolation.
    chart_dimensions = (
        mo.md('''
        > **Adjust Chart Window**

        {chart_height}

        {chat_width}

        ''').batch(
            chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
            chat_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
        )
    )
    return (chart_dimensions,)
|
|
|
|
|
@app.cell |
|
def chart_dim_values(chart_dimensions):
    # The batch form exposes both sliders through one value dict; note the
    # width key is spelled 'chat_width' upstream, so it is read verbatim.
    dims = chart_dimensions.value
    chart_height = dims['chart_height']
    chart_width = dims['chat_width']
    return chart_height, chart_width
|
|
|
|
|
@app.cell |
|
def create_baseline_chart(
    chart_dataframe,
    chart_height,
    chart_metadata,
    chart_width,
    create_3d_embedding_chart,
):
    # Render only when both the reduced coordinates and metadata exist;
    # otherwise hand back None so dependent cells can no-op.
    emb_plot = None
    chart = None
    if chart_dataframe is not None and chart_metadata is not None:
        emb_plot = create_3d_embedding_chart(
            chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9
        )
        chart = mo.ui.plotly(emb_plot)
    return (emb_plot,)
|
|
|
|
|
@app.cell |
|
def test_query(get_chunk_state):
    # Default query shown in the text area on first render.
    placeholder = """How can i use watsonx.data to perform vector search?"""

    # Submit stays disabled until chunks exist (check_state -> True on None).
    text_area = mo.ui.text_area(
        label="**Write text to check:**",
        full_width=True,
        rows=8,
        value=placeholder,
    )
    query = text_area.form(
        show_clear_button=True,
        submit_button_disabled=check_state(get_chunk_state()),
    )
    return (query,)
|
|
|
|
|
@app.cell |
|
def query_stack(chart_dimensions, query):
    # Lay the query form and the chart-size controls side by side.
    query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
    return (query_stack,)
|
|
|
|
|
@app.function |
|
def check_state(variable):
    """Return True when *variable* is None — used to disable submit buttons until state exists."""
    return variable is None
|
|
|
|
|
@app.cell |
|
def helper_function_add_query_to_chart():
    def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
        """
        Add a query point to an existing 3D embedding chart as a large red dot.

        Args:
            existing_chart: The existing plotly figure or chart data
            query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
            query_text: Text of the query to display on hover
            marker_size: Size of the query marker (default: 12)

        Returns:
            A modified plotly figure with the query point added as a red dot
        """
        import plotly.graph_objects as go

        # Deep-copy so the caller's chart object is left untouched.
        import copy
        chart_copy = copy.deepcopy(existing_chart)

        # mo.ui.plotly round-trips may yield raw dict/list chart data rather
        # than a Figure; rebuild a Figure so add_trace() is available.
        if isinstance(chart_copy, (dict, list)):
            # NOTE(review): redundant re-import of go (already imported above); harmless.
            import plotly.graph_objects as go

            if isinstance(chart_copy, list):
                # A bare list is a list of traces with no layout.
                fig = go.Figure(data=chart_copy)
            else:
                fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))

            chart_copy = fig

        query_trace = go.Scatter3d(
            x=[query_coords['x']],
            y=[query_coords['y']],
            z=[query_coords['z']],
            mode='markers',
            name='Query',
            marker=dict(
                size=marker_size,
                color='red',
                symbol='circle',
                opacity=0.70,
                line=dict(
                    width=1,
                    color='white'
                )
            ),
            # Hover text hard-wrapped at 50 characters per line.
            text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])],
            hoverinfo="text+name"
        )

        chart_copy.add_trace(query_trace)

        return chart_copy

    def get_query_coordinates(reference_embeddings=None, query_embedding=None):
        """
        Calculate appropriate coordinates for a query point based on reference embeddings.

        This function handles several scenarios:
        1. If both reference embeddings and query embedding are provided, it places the
        query near similar documents.
        2. If only reference embeddings are provided, it places the query at a visible
        location near the center of the chart.
        3. If neither are provided, it returns default origin coordinates.

        Args:
            reference_embeddings: DataFrame with x, y, z coordinates from the main chart
            query_embedding: The embedding vector of the query

        Returns:
            Dictionary with x, y, z coordinates for the query point
        """
        import numpy as np

        # Scenario 3: nothing to anchor against — origin.
        default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}

        if reference_embeddings is None or len(reference_embeddings) == 0:
            return default_coords

        # Scenario 2: no query embedding — centroid of the plotted points.
        if query_embedding is None:
            center_coords = {
                'x': reference_embeddings['x'].mean(),
                'y': reference_embeddings['y'].mean(),
                'z': reference_embeddings['z'].mean()
            }
            return center_coords

        # Scenario 1: place the query next to its most cosine-similar document.
        try:
            from sklearn.metrics.pairwise import cosine_similarity

            if 'embedding' in reference_embeddings.columns:
                # Stack stored per-row embeddings into a 2D array.
                if isinstance(reference_embeddings['embedding'].iloc[0], list):
                    doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
                else:
                    doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])

                # cosine_similarity needs a 2D (1, dim) query.
                query_emb_array = np.array(query_embedding)
                if query_emb_array.ndim == 1:
                    query_emb_array = query_emb_array.reshape(1, -1)

                similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]

                closest_idx = np.argmax(similarities)

                # Small +0.2 offset keeps the query marker from hiding the
                # document point it sits next to.
                query_coords = {
                    'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
                    'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
                    'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
                }
                return query_coords
        except Exception as e:
            print(f"Error positioning query near similar documents: {e}")

        # Fallback (no 'embedding' column or similarity failure): centroid.
        center_coords = {
            'x': reference_embeddings['x'].mean(),
            'y': reference_embeddings['y'].mean(),
            'z': reference_embeddings['z'].mean()
        }
        return center_coords
    return add_query_to_embedding_chart, get_query_coordinates
|
|
|
|
|
@app.cell |
|
def combined_chart_visualization(
    add_query_to_embedding_chart,
    chart_dataframe,
    emb_plot,
    embedding,
    get_query_coordinates,
    get_query_state,
    query,
    set_chart_state,
    set_query_state,
):
    # Re-runs when the query form is submitted: embeds the query text,
    # positions it among the document embeddings, and overlays it on the
    # baseline 3D chart, publishing the result via shared state.
    if chart_dataframe is not None and query.value:
        with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
            # Embed the query with the same model as the documents.
            query_emb = embedding.embed_documents([query.value])
            set_query_state(query_emb)

            _spinner.update("Preparing Query Coordinates")
            time.sleep(1.0)  # brief pause so the spinner message is readable

            query_coords = get_query_coordinates(
                reference_embeddings=chart_dataframe,
                query_embedding=get_query_state()
            )

            _spinner.update("Adding Query to Chart")
            time.sleep(1.0)

            # Overlay the red query marker on a copy of the baseline chart.
            result = add_query_to_embedding_chart(
                existing_chart=emb_plot,
                query_coords=query_coords,
                query_text=query.value,
            )

            chart_with_query = result

            _spinner.update("Preparing Visualization")
            time.sleep(1.0)

            combined_viz = mo.ui.plotly(chart_with_query)
            set_chart_state(combined_viz)

            _spinner.update("Done")
    else:
        combined_viz = None
    return
|
|
|
|
|
@app.cell |
|
def _():
    # Shared marimo state holding the range-slider UI element (None until built).
    get_range_slider_state, set_range_slider_state = mo.state(None)
    return get_range_slider_state, set_range_slider_state
|
|
|
|
|
@app.cell |
|
def _(get_range_slider_state):
    # The previous if/else returned the state's value when set and None
    # otherwise — which is exactly what the getter already returns; it
    # also invoked the getter twice. Collapsed to a single read.
    document_range_stack = get_range_slider_state()
    return (document_range_stack,)
|
|
|
|
|
@app.cell |
|
def _():
    # Shared marimo state holding the combined chart visualization (None until rendered).
    get_chart_state, set_chart_state = mo.state(None)
    return get_chart_state, set_chart_state
|
|
|
|
|
@app.cell |
|
def _(get_chart_state, query):
    # Surface the cached chart only while the query form holds a value.
    chart_visualization = get_chart_state() if query.value is not None else None
    return (chart_visualization,)
|
|
|
|
|
@app.cell |
|
def c(document_range_stack):
    # Center the chunk-range slider form at 65% of the available width.
    chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
    return (chart_range_selection,)
|
|
|
|
|
if __name__ == "__main__":
    # Launch the marimo app when this file is executed directly.
    app.run()
|
|