MilanM commited on
Commit
b17b1a5
·
verified ·
1 Parent(s): 2eb0521

Create visualizer_app.py

Browse files
Files changed (1) hide show
  1. visualizer_app.py +2056 -0
visualizer_app.py ADDED
@@ -0,0 +1,2056 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import marimo
2
+
3
+ __generated_with = "0.13.0"
4
+ app = marimo.App(width="full")
5
+
6
+ with app.setup:
7
+ # Initialization code that runs before all other cells
8
+ import marimo as mo
9
+ from typing import Dict, Optional, List, Union, Any
10
+ from ibm_watsonx_ai import APIClient, Credentials
11
+ from pathlib import Path
12
+ import pandas as pd
13
+ import mimetypes
14
+ import requests
15
+ import zipfile
16
+ import tempfile
17
+ import base64
18
+ import polars
19
+ import time
20
+ import json
21
+ import ast
22
+ import os
23
+ import io
24
+ import re
25
+
26
+ def get_iam_token(api_key):
27
+ return requests.post(
28
+ 'https://iam.cloud.ibm.com/identity/token',
29
+ headers={'Content-Type': 'application/x-www-form-urlencoded'},
30
+ data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key}
31
+ ).json()['access_token']
32
+
33
+ def setup_task_credentials(client):
34
+ # Get existing task credentials
35
+ existing_credentials = client.task_credentials.get_details()
36
+
37
+ # Delete existing credentials if any
38
+ if "resources" in existing_credentials and existing_credentials["resources"]:
39
+ for cred in existing_credentials["resources"]:
40
+ cred_id = client.task_credentials.get_id(cred)
41
+ client.task_credentials.delete(cred_id)
42
+
43
+ # Store new credentials
44
+ return client.task_credentials.store()
45
+
46
+ def get_cred_value(key, creds_var_name="baked_in_creds", default=""): ### Helper for working with preset credentials
47
+ """
48
+ Helper function to safely get a value from a credentials dictionary.
49
+
50
+ Args:
51
+ key: The key to look up in the credentials dictionary.
52
+ creds_var_name: The variable name of the credentials dictionary.
53
+ default: The default value to return if the key is not found.
54
+
55
+ Returns:
56
+ The value from the credentials dictionary if it exists and contains the key,
57
+ otherwise returns the default value.
58
+ """
59
+ # Check if the credentials variable exists in globals
60
+ if creds_var_name in globals():
61
+ creds_dict = globals()[creds_var_name]
62
+ if isinstance(creds_dict, dict) and key in creds_dict:
63
+ return creds_dict[key]
64
+ return default
65
+
66
+ @app.cell
67
+ def client_variables(client_instantiation_form):
68
+ if client_instantiation_form.value:
69
+ client_setup = client_instantiation_form.value
70
+ else:
71
+ client_setup = None
72
+
73
+ ### Extract Credential Variables:
74
+ if client_setup is not None:
75
+ wx_url = client_setup["wx_region"]
76
+ wx_api_key = client_setup["wx_api_key"].strip()
77
+ os.environ["WATSONX_APIKEY"] = wx_api_key
78
+
79
+ if client_setup["project_id"] is not None:
80
+ project_id = client_setup["project_id"].strip()
81
+ else:
82
+ project_id = None
83
+
84
+ if client_setup["space_id"] is not None:
85
+ space_id = client_setup["space_id"].strip()
86
+ else:
87
+ space_id = None
88
+
89
+ else:
90
+ os.environ["WATSONX_APIKEY"] = ""
91
+ project_id = None
92
+ space_id = None
93
+ wx_api_key = None
94
+ wx_url = None
95
+ return client_setup, project_id, space_id, wx_api_key, wx_url
96
+
97
+
98
+ @app.cell
99
+ def _(client_setup, wx_api_key):
100
+ if client_setup:
101
+ token = get_iam_token(wx_api_key)
102
+ else:
103
+ token = None
104
+ return
105
+
106
+ @app.cell
107
+ def _():
108
+ baked_in_creds = {
109
+ "purpose": "",
110
+ "api_key": "",
111
+ "project_id": "",
112
+ "space_id": "",
113
+ }
114
+ return baked_in_creds
115
+
116
+
117
+ @app.cell
118
+ def client_instantiation(
119
+ client_setup,
120
+ project_id,
121
+ space_id,
122
+ wx_api_key,
123
+ wx_url,
124
+ ):
125
+ ### Instantiate the watsonx.ai client
126
+ if client_setup:
127
+ wx_credentials = Credentials(
128
+ url=wx_url,
129
+ api_key=wx_api_key
130
+ )
131
+
132
+ if project_id:
133
+ project_client = APIClient(credentials=wx_credentials, project_id=project_id)
134
+ else:
135
+ project_client = None
136
+
137
+ if space_id:
138
+ deployment_client = APIClient(credentials=wx_credentials, space_id=space_id)
139
+ else:
140
+ deployment_client = None
141
+
142
+ if project_client is not None:
143
+ task_credentials_details = setup_task_credentials(project_client)
144
+ else:
145
+ task_credentials_details = setup_task_credentials(deployment_client)
146
+ else:
147
+ wx_credentials = None
148
+ project_client = None
149
+ deployment_client = None
150
+ task_credentials_details = None
151
+
152
+ client_status = mo.md("### Client Instantiation Status will turn Green When Ready")
153
+
154
+ if project_client is not None or deployment_client is not None:
155
+ client_callout_kind = "success"
156
+ else:
157
+ client_callout_kind = "neutral"
158
+ return (
159
+ client_callout_kind,
160
+ client_status,
161
+ deployment_client,
162
+ project_client,
163
+ )
164
+
165
+
166
+ @app.cell
167
+ def _():
168
+ mo.md(
169
+ r"""
170
+ #watsonx.ai Embedding Visualizer - Marimo Notebook
171
+
172
+ #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages.
173
+
174
+ #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy or contents used in building collections by showing how far you are from what you want.
175
+ <br>
176
+
177
+ /// admonition
178
+ Created by ***Milan Mrdenovic*** [[email protected]] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025*
179
+ ///
180
+
181
+
182
+ >Licensed under apache 2.0, users hold full accountability for any use or modification of the code.
183
+ ><br>This asset is part of a set meant to support IBMers, IBM Partners, Clients in developing understanding of how to better utilize various watsonx features and generative AI as a subject matter.
184
+
185
+ <br>
186
+ """
187
+ )
188
+ return
189
+
190
+
191
+ @app.cell
192
+ def _():
193
+ mo.md("""###Part 1 - Client Setup, File Preparation and Chunking""")
194
+ return
195
+
196
+
197
+ @app.cell
198
+ def accordion_client_setup(client_selector, client_stack):
199
+ ui_accordion_part_1_1 = mo.accordion(
200
+ {
201
+ "Instantiate Client": mo.vstack([client_stack, client_selector], align="center"),
202
+ }
203
+ )
204
+
205
+ ui_accordion_part_1_1
206
+ return
207
+
208
+
209
+ @app.cell
210
+ def accordion_file_upload(select_stack):
211
+ ui_accordion_part_1_2 = mo.accordion(
212
+ {
213
+ "Select Model & Upload Files": select_stack
214
+ }
215
+ )
216
+
217
+ ui_accordion_part_1_2
218
+ return
219
+
220
+
221
+ @app.cell
222
+ def loaded_texts(
223
+ create_temp_files_from_uploads,
224
+ file_loader,
225
+ pdf_reader,
226
+ run_upload_button,
227
+ set_text_state,
228
+ ):
229
+ if file_loader.value is not None and run_upload_button.value:
230
+ filepaths = create_temp_files_from_uploads(file_loader.value)
231
+ loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True)
232
+
233
+ set_text_state(loaded_texts)
234
+ else:
235
+ filepaths = None
236
+ loaded_texts = None
237
+ return
238
+
239
+
240
+ @app.cell
241
+ def accordion_chunker_setup(chunker_setup):
242
+ ui_accordion_part_1_3 = mo.accordion(
243
+ {
244
+ "Chunker Setup": chunker_setup
245
+ }
246
+ )
247
+
248
+ ui_accordion_part_1_3
249
+ return
250
+
251
+
252
+ @app.cell
253
+ def chunk_documents_to_nodes(
254
+ get_text_state,
255
+ sentence_splitter,
256
+ sentence_splitter_config,
257
+ set_chunk_state,
258
+ ):
259
+ if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
260
+ chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
261
+ set_chunk_state(chunked_texts)
262
+ else:
263
+ chunked_texts = None
264
+ return (chunked_texts,)
265
+
266
+
267
+ @app.cell
268
+ def _():
269
+ mo.md(r"""###Part 2 - Query Setup and Visualization""")
270
+ return
271
+
272
+
273
+ @app.cell
274
+ def accordion_chunk_range(chart_range_selection):
275
+ ui_accordion_part_2_1 = mo.accordion(
276
+ {
277
+ "Chunk Range Selection": chart_range_selection
278
+ }
279
+ )
280
+ ui_accordion_part_2_1
281
+ return
282
+
283
+
284
+ @app.cell
285
+ def chunk_embedding(
286
+ chunks_to_process,
287
+ embedding,
288
+ sentence_splitter_config,
289
+ set_embedding_state,
290
+ ):
291
+ if sentence_splitter_config.value is not None and chunks_to_process is not None:
292
+ with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
293
+ output_embeddings = embedding.embed_documents(chunks_to_process)
294
+ _spinner.update("Almost Done")
295
+ time.sleep(1.5)
296
+ set_embedding_state(output_embeddings)
297
+ _spinner.update("Documents Embedded")
298
+ else:
299
+ output_embeddings = None
300
+ return
301
+
302
+
303
+ @app.cell
304
+ def preview_chunks(chunks_dict):
305
+ if chunks_dict is not None:
306
+ stats = create_stats(chunks_dict,
307
+ bordered=True,
308
+ object_names=['text','text'],
309
+ group_by_row=True,
310
+ items_per_row=5,
311
+ gap=1,
312
+ label="Chunk")
313
+ ui_chunk_viewer = mo.accordion(
314
+ {
315
+ "View Chunks": stats,
316
+ }
317
+ )
318
+ else:
319
+ ui_chunk_viewer = None
320
+
321
+ ui_chunk_viewer
322
+ return
323
+
324
+
325
+ @app.cell
326
+ def accordion_query_view(chart_visualization, query_stack):
327
+ ui_accordion_part_2_2 = mo.accordion(
328
+ {
329
+ "Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3)
330
+ }
331
+ )
332
+ ui_accordion_part_2_2
333
+ return
334
+
335
+
336
+ @app.cell
337
+ def chunker_setup(sentence_splitter_config):
338
+ chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55])
339
+ return (chunker_setup,)
340
+
341
+
342
+ @app.cell
343
+ def file_and_model_select(
344
+ file_loader,
345
+ get_embedding_model_list,
346
+ run_upload_button,
347
+ ):
348
+ select_stack = mo.hstack([get_embedding_model_list(), mo.vstack([file_loader, run_upload_button], align="center")], justify="space-around", align="center", widths=[0.3,0.3])
349
+ return (select_stack,)
350
+
351
+
352
+ @app.cell
353
+ def client_instantiation_form():
354
+ # Endpoints
355
+ wx_platform_url = "https://api.dataplatform.cloud.ibm.com"
356
+ regions = {
357
+ "US": "https://us-south.ml.cloud.ibm.com",
358
+ "EU": "https://eu-de.ml.cloud.ibm.com",
359
+ "GB": "https://eu-gb.ml.cloud.ibm.com",
360
+ "JP": "https://jp-tok.ml.cloud.ibm.com",
361
+ "AU": "https://au-syd.ml.cloud.ibm.com",
362
+ "CA": "https://ca-tor.ml.cloud.ibm.com"
363
+ }
364
+
365
+ # Create a form with multiple elements
366
+ client_instantiation_form = (
367
+ mo.md('''
368
+ ###**watsonx.ai credentials:**
369
+
370
+ {wx_region}
371
+
372
+ {wx_api_key}
373
+
374
+ {project_id}
375
+
376
+ {space_id}
377
+ ''')
378
+ .batch(
379
+ wx_region = mo.ui.dropdown(regions, label="Select your watsonx.ai region:", value="US", searchable=True),
380
+ wx_api_key = mo.ui.text(placeholder="Add your IBM Cloud api-key...", label="IBM Cloud Api-key:",
381
+ kind="password", value=get_cred_value('api_key', creds_var_name='baked_in_creds')),
382
+ project_id = mo.ui.text(placeholder="Add your watsonx.ai project_id...", label="Project_ID:",
383
+ kind="text", value=get_cred_value('project_id', creds_var_name='baked_in_creds')),
384
+ space_id = mo.ui.text(placeholder="Add your watsonx.ai space_id...", label="Space_ID:",
385
+ kind="text", value=get_cred_value('space_id', creds_var_name='baked_in_creds'))
386
+ ,)
387
+ .form(show_clear_button=True, bordered=False)
388
+ )
389
+ return (client_instantiation_form,)
390
+
391
+
392
+ @app.cell
393
+ def instantiation_status(
394
+ client_callout_kind,
395
+ client_instantiation_form,
396
+ client_status,
397
+ ):
398
+ client_callout = mo.callout(client_status, kind=client_callout_kind)
399
+ client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10)
400
+ return (client_stack,)
401
+
402
+
403
+ @app.cell
404
+ def client_selector(deployment_client, project_client):
405
+ if deployment_client is not None:
406
+ client_options = {"Deployment Client":deployment_client}
407
+
408
+ elif project_client is not None:
409
+ client_options = {"Project Client":project_client}
410
+
411
+ elif project_client is not None and deployment_client is not None:
412
+ client_options = {"Project Client":project_client,"Deployment Client":deployment_client}
413
+
414
+ else:
415
+ client_options = {"No Client": "Instantiate a Client"}
416
+
417
+ default_client = next(iter(client_options))
418
+ client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Select your active client:**")
419
+
420
+ return (client_selector,)
421
+
422
+
423
+ @app.cell
424
+ def active_client(client_selector):
425
+ client = client_selector.value
426
+ return (client,)
427
+
428
+
429
+ @app.cell
430
+ def emb_model_selection(client, set_embedding_model_list):
431
+ if client:
432
+ model_specs = client.foundation_models.get_embeddings_model_specs()
433
+ # model_specs = client.foundation_models.get_model_specs()
434
+ resources = model_specs["resources"]
435
+ # Define embedding models reference data
436
+ embedding_models = {
437
+ "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
438
+ "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
439
+ "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
440
+ "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
441
+ "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
442
+ "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
443
+ "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
444
+ "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
445
+ "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
446
+ }
447
+
448
+ # Get model IDs from resources
449
+ model_id_list = []
450
+ for resource in resources:
451
+ model_id_list.append(resource["model_id"])
452
+
453
+ # Create enhanced model data for the table
454
+ embedding_model_data = []
455
+ for model_id in model_id_list:
456
+ model_entry = {"model_id": model_id}
457
+
458
+ # Add properties if model exists in our reference, otherwise use 0
459
+ if model_id in embedding_models:
460
+ model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"]
461
+ model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"]
462
+ else:
463
+ model_entry["max_tokens"] = 0
464
+ model_entry["embedding_dimensions"] = 0
465
+
466
+ embedding_model_data.append(model_entry)
467
+
468
+ embedding_model_selection = mo.ui.table(
469
+ embedding_model_data,
470
+ selection="single", # Only allow selecting one row
471
+ label="Select an embedding model to use.",
472
+ page_size=30,
473
+ initial_selection=[1]
474
+ )
475
+ set_embedding_model_list(embedding_model_selection)
476
+ else:
477
+ default_model_data = [{
478
+ "model_id": "ibm/granite-embedding-107m-multilingual",
479
+ "max_tokens": 512,
480
+ "embedding_dimensions": 384
481
+ }]
482
+
483
+ set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
484
+ return
485
+
486
+
487
+ @app.function
488
+ def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
489
+ embedding_model_selection = mo.ui.table(
490
+ model_data,
491
+ selection=selection_type, # Only allow selecting one row
492
+ label=label,
493
+ page_size=30,
494
+ initial_selection=[initial_selection]
495
+ )
496
+ return embedding_model_selection
497
+
498
+
499
+ @app.cell
500
+ def embedding_model():
501
+ get_embedding_model_list, set_embedding_model_list = mo.state(None)
502
+ return get_embedding_model_list, set_embedding_model_list
503
+
504
+
505
+ @app.cell
506
+ def emb_model_parameters(emb_model_max_tk):
507
+ from ibm_watsonx_ai.foundation_models import Embeddings
508
+ from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
509
+
510
+ embed_params = {
511
+ EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk,
512
+ EmbedParams.RETURN_OPTIONS: {
513
+ 'input_text': True
514
+ }
515
+ }
516
+ return Embeddings, embed_params
517
+
518
+
519
+ @app.cell
520
+ def emb_model_state(get_embedding_model_list):
521
+ embedding_model = get_embedding_model_list()
522
+ return (embedding_model,)
523
+
524
+
525
+ @app.cell
526
+ def emb_model_setup(embedding_model):
527
+ emb_model = embedding_model.value[0]['model_id']
528
+ emb_model_max_tk = embedding_model.value[0]['max_tokens']
529
+ emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions']
530
+ return emb_model, emb_model_emb_dim, emb_model_max_tk
531
+
532
+
533
+ @app.cell
534
+ def emb_model_instantiation(Embeddings, client, emb_model, embed_params):
535
+ if client is not None:
536
+ embedding = Embeddings(
537
+ model_id=emb_model,
538
+ api_client=client,
539
+ params=embed_params,
540
+ batch_size=1000,
541
+ concurrency_limit=10
542
+ )
543
+ else:
544
+ embedding = None
545
+ return (embedding,)
546
+
547
+
548
+ @app.cell
549
+ def _():
550
+ get_embedding_state, set_embedding_state = mo.state(None)
551
+ return get_embedding_state, set_embedding_state
552
+
553
+
554
+ @app.cell
555
+ def _():
556
+ get_query_state, set_query_state = mo.state(None)
557
+ return get_query_state, set_query_state
558
+
559
+
560
+ @app.cell
561
+ def file_loader_input():
562
+ file_loader = mo.ui.file(
563
+ kind="area",
564
+ filetypes=[".pdf"],
565
+ label=" Load .pdf files ",
566
+ multiple=True
567
+ )
568
+ return (file_loader,)
569
+
570
+
571
+ @app.cell
572
+ def file_loader_run(file_loader):
573
+ if file_loader.value is not None:
574
+ run_upload_button = mo.ui.run_button(label="Load Files")
575
+ else:
576
+ run_upload_button = mo.ui.run_button(disabled=True, label="Load Files")
577
+ return (run_upload_button,)
578
+
579
+
580
+ @app.cell
581
+ def helper_function_tempfiles():
582
+ def create_temp_files_from_uploads(upload_results) -> List[str]:
583
+ """
584
+ Creates temporary files from a tuple of FileUploadResults objects and returns their paths.
585
+ Args:
586
+ upload_results: Object containing a value attribute that is a tuple of FileUploadResults
587
+ Returns:
588
+ List of temporary file paths
589
+ """
590
+ temp_file_paths = []
591
+
592
+ # Get the number of items in the tuple
593
+ num_items = len(upload_results)
594
+
595
+ # Process each item by index
596
+ for i in range(num_items):
597
+ result = upload_results[i] # Get item by index
598
+
599
+ # Create a temporary file with the original filename
600
+ temp_dir = tempfile.gettempdir()
601
+ file_name = result.name
602
+ temp_path = os.path.join(temp_dir, file_name)
603
+ # Write the contents to the temp file
604
+ with open(temp_path, 'wb') as temp_file:
605
+ temp_file.write(result.contents)
606
+ # Add the path to our list
607
+ temp_file_paths.append(temp_path)
608
+
609
+ return temp_file_paths
610
+
611
+ def cleanup_temp_files(temp_file_paths: List[str]) -> None:
612
+ """Delete temporary files after use."""
613
+ for path in temp_file_paths:
614
+ if os.path.exists(path):
615
+ os.unlink(path)
616
+ return (create_temp_files_from_uploads,)
617
+
618
+
619
+ @app.function
620
+ def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
621
+ """
622
+ Loads PDF data for each file path and organizes results by original filename.
623
+ Args:
624
+ pdf_reader: The PyMuPDFReader instance
625
+ filepaths: List of temporary file paths
626
+ file_loader_value: The original upload results value containing file information
627
+ show_progress: Whether to show a progress bar during loading (default: False)
628
+ Returns:
629
+ Dictionary mapping original filenames to their loaded text content
630
+ """
631
+ results = {}
632
+
633
+ # Process files with or without progress bar
634
+ if show_progress:
635
+ import marimo as mo
636
+ # Use progress bar with the length of filepaths as total
637
+ with mo.status.progress_bar(
638
+ total=len(filepaths),
639
+ title="Loading PDFs",
640
+ subtitle="Processing documents...",
641
+ completion_title="PDF Loading Complete",
642
+ completion_subtitle=f"{len(filepaths)} documents processed",
643
+ remove_on_exit=True
644
+ ) as bar:
645
+ # Process each file path
646
+ for i, file_path in enumerate(filepaths):
647
+
648
+ original_file_name = file_loader_value[i].name
649
+ bar.update(subtitle=f"Processing {original_file_name}...")
650
+ loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
651
+
652
+ # Store the result with the original filename as the key
653
+ results[original_file_name] = loaded_text
654
+ # Update progress bar
655
+ bar.update(increment=1)
656
+ else:
657
+ # Original logic without progress bar
658
+ for i, file_path in enumerate(filepaths):
659
+ original_file_name = file_loader_value[i].name
660
+ loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
661
+ results[original_file_name] = loaded_text
662
+
663
+ return results
664
+
665
+
666
+ @app.cell
667
+ def file_readers():
668
+ from llama_index.readers.file import PyMuPDFReader
669
+ from llama_index.readers.file import FlatReader
670
+ from llama_index.core.node_parser import SentenceSplitter
671
+
672
+ ### File Readers
673
+ pdf_reader = PyMuPDFReader()
674
+ # flat_file_reader = FlatReader()
675
+ return SentenceSplitter, pdf_reader
676
+
677
+
678
+ @app.cell
679
+ def sentence_splitter_setup():
680
+ ### Chunker Setup
681
+ sentence_splitter_config = (
682
+ mo.md('''
683
+ ###**Chunking Setup:**
684
+
685
+ > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might result in unexpected results.
686
+
687
+ Separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.
688
+
689
+ {chunk_size} {chunk_overlap}
690
+
691
+ {separator} {paragraph_separator}
692
+
693
+ {secondary_chunking_regex}
694
+
695
+ {include_metadata}
696
+
697
+ ''')
698
+ .batch(
699
+ chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk SizeL**", value=350, show_value=True),
700
+ chunk_overlap = mo.ui.slider(start=1, stop=1000, step=1, label="**Chunk Overlap:**", value=50, show_value=True),
701
+ separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
702
+ paragraph_separator = mo.ui.text(placeholder="Define a paragraph separator",
703
+ label="**Paragraph Separator:**", kind="text",
704
+ value="\n\n\n"),
705
+ secondary_chunking_regex = mo.ui.text(placeholder="Define a secondary chunking regex",
706
+ label="**Chunking Regex:**", kind="text",
707
+ value="[^,.;?!]+[,.;?!]?"),
708
+ include_metadata= mo.ui.checkbox(value=True, label="**Include Metadata**")
709
+ )
710
+ .form(show_clear_button=True, bordered=False)
711
+ )
712
+ return (sentence_splitter_config,)
713
+
714
+
715
+ @app.cell
716
+ def sentence_splitter_instantiation(
717
+ SentenceSplitter,
718
+ sentence_splitter_config,
719
+ ):
720
+ ### Chunker/Sentence Splitter
721
+ def simple_whitespace_tokenizer(text):
722
+ return text.split()
723
+
724
+ if sentence_splitter_config.value is not None:
725
+ sentence_splitter_config_values = sentence_splitter_config.value
726
+ validated_chunk_overlap = min(sentence_splitter_config_values.get("chunk_overlap"),
727
+ int(sentence_splitter_config_values.get("chunk_size") * 0.3))
728
+
729
+ sentence_splitter = SentenceSplitter(
730
+ chunk_size=sentence_splitter_config_values.get("chunk_size"),
731
+ chunk_overlap=validated_chunk_overlap,
732
+ separator=sentence_splitter_config_values.get("separator"),
733
+ paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"),
734
+ secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"),
735
+ include_metadata=sentence_splitter_config_values.get("include_metadata"),
736
+ tokenizer=simple_whitespace_tokenizer
737
+ )
738
+
739
+ else:
740
+ sentence_splitter = SentenceSplitter(
741
+ chunk_size=2048,
742
+ chunk_overlap=204,
743
+ separator=" ",
744
+ paragraph_separator="\n\n\n",
745
+ secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
746
+ include_metadata=True,
747
+ tokenizer=simple_whitespace_tokenizer
748
+ )
749
+ return (sentence_splitter,)
750
+
751
+
752
+ @app.cell
753
+ def text_state():
754
+ get_text_state, set_text_state = mo.state(None)
755
+ return get_text_state, set_text_state
756
+
757
+
758
+ @app.cell
759
+ def chunk_state():
760
+ get_chunk_state, set_chunk_state = mo.state(None)
761
+ return get_chunk_state, set_chunk_state
762
+
763
+
764
+ @app.function
765
+ def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
766
+ """
767
+ Process each document in the loaded_texts dictionary using the sentence_splitter,
768
+ with an optional marimo progress bar tracking progress at document level.
769
+
770
+ Args:
771
+ loaded_texts (dict): Dictionary containing lists of Document objects
772
+ sentence_splitter: The sentence splitter object with get_nodes_from_documents method
773
+ show_progress (bool): Whether to show a progress bar during processing
774
+
775
+ Returns:
776
+ dict: Dictionary with the same structure but containing chunked texts
777
+ """
778
+ chunked_texts_dict = {}
779
+
780
+ # Get the total number of documents across all keys
781
+ total_docs = sum(len(docs) for docs in loaded_texts.values())
782
+ processed_docs = 0
783
+
784
+ # Process with or without progress bar
785
+ if show_progress:
786
+ import marimo as mo
787
+ # Use progress bar with the total number of documents as total
788
+ with mo.status.progress_bar(
789
+ total=total_docs,
790
+ title="Processing Documents",
791
+ subtitle="Chunking documents...",
792
+ completion_title="Processing Complete",
793
+ completion_subtitle=f"{total_docs} documents processed",
794
+ remove_on_exit=True
795
+ ) as bar:
796
+ # Process each key-value pair in the loaded_texts dictionary
797
+ for key, documents in loaded_texts.items():
798
+ # Update progress bar subtitle to show current key
799
+ doc_count = len(documents)
800
+ bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")
801
+
802
+ # Apply the sentence splitter to each list of documents
803
+ chunked_texts = sentence_splitter.get_nodes_from_documents(
804
+ documents,
805
+ show_progress=False # Disable internal progress to avoid nested bars
806
+ )
807
+
808
+ # Store the result with the same key
809
+ chunked_texts_dict[key] = chunked_texts
810
+ time.sleep(0.15)
811
+
812
+ # Update progress bar with the number of documents in this batch
813
+ bar.update(increment=doc_count)
814
+ processed_docs += doc_count
815
+ else:
816
+ # Process without progress bar
817
+ for key, documents in loaded_texts.items():
818
+ chunked_texts = sentence_splitter.get_nodes_from_documents(
819
+ documents,
820
+ show_progress=True # Use the internal progress bar if no marimo bar
821
+ )
822
+ chunked_texts_dict[key] = chunked_texts
823
+
824
+ return chunked_texts_dict
825
+
826
+
827
+ @app.cell
828
+ def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
829
+ if chunked_texts is not None and sentence_splitter:
830
+ chunked_documents = get_chunk_state()
831
+ else:
832
+ chunked_documents = None
833
+ return (chunked_documents,)
834
+
835
+
836
+ @app.cell
837
+ def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
838
+ if chunked_documents is not None:
839
+ dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
840
+ nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)
841
+ else:
842
+ dict_from_nodes = None
843
+ nodes_from_dict = None
844
+ return (dict_from_nodes,)
845
+
846
+
847
+ @app.cell
848
+ def chunks_to_process(
849
+ dict_from_nodes,
850
+ document_range_stack,
851
+ get_data_in_range_triplequote,
852
+ ):
853
+ if dict_from_nodes is not None and document_range_stack.value is not None:
854
+
855
+ chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)
856
+
857
+ if document_range_stack.value is not None:
858
+ chunk_start_idx = document_range_stack.value[0]
859
+ chunk_end_idx = document_range_stack.value[1]
860
+ else:
861
+ chunk_start_idx = 0
862
+ chunk_end_idx = len(chunk_dict_df)
863
+
864
+ chunk_range_index = [chunk_start_idx, chunk_end_idx]
865
+ chunks_dict = get_data_in_range_triplequote(chunk_dict_df,
866
+ index_range=chunk_range_index,
867
+ columns_to_include=["text"])
868
+
869
+ chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
870
+ else:
871
+ chunk_objects = None
872
+ chunks_dict = None
873
+ chunks_to_process = None
874
+ return chunks_dict, chunks_to_process
875
+
876
+
877
+ @app.cell
878
+ def helper_function_doc_formatting():
879
+ def llamaindex_convert_docs_multi(items):
880
+ """
881
+ Automatically convert between document objects and dictionaries.
882
+
883
+ This function handles:
884
+ - Converting dictionaries to document objects
885
+ - Converting document objects to dictionaries
886
+ - Processing lists or individual items
887
+ - Supporting dictionary structures where values are lists of documents
888
+
889
+ Args:
890
+ items: A document object, dictionary, or list of either.
891
+ Can also be a dictionary mapping filenames to lists of documents.
892
+
893
+ Returns:
894
+ Converted item(s) maintaining the original structure
895
+ """
896
+ # Handle empty or None input
897
+ if not items:
898
+ return []
899
+
900
+ # Handle dictionary mapping filenames to document lists (from load_pdf_data)
901
+ if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
902
+ result = {}
903
+ for filename, doc_list in items.items():
904
+ result[filename] = llamaindex_convert_docs(doc_list)
905
+ return result
906
+
907
+ # Handle single items (not in a list)
908
+ if not isinstance(items, list):
909
+ # Single dictionary to document
910
+ if isinstance(items, dict):
911
+ # Determine document class
912
+ doc_class = None
913
+ if 'doc_type' in items:
914
+ import importlib
915
+ module_path, class_name = items['doc_type'].rsplit('.', 1)
916
+ module = importlib.import_module(module_path)
917
+ doc_class = getattr(module, class_name)
918
+ if not doc_class:
919
+ from llama_index.core.schema import Document
920
+ doc_class = Document
921
+ return doc_class.from_dict(items)
922
+ # Single document to dictionary
923
+ elif hasattr(items, 'to_dict'):
924
+ return items.to_dict()
925
+ # Return as is if can't convert
926
+ return items
927
+
928
+ # Handle list input
929
+ result = []
930
+
931
+ # Handle empty list
932
+ if len(items) == 0:
933
+ return result
934
+
935
+ # Determine the type of conversion based on the first non-None item
936
+ first_item = next((item for item in items if item is not None), None)
937
+
938
+ # If we found no non-None items, return empty list
939
+ if first_item is None:
940
+ return result
941
+
942
+ # Convert dictionaries to documents
943
+ if isinstance(first_item, dict):
944
+ # Get the right document class from the items themselves
945
+ doc_class = None
946
+ # Try to get doc class from metadata if available
947
+ if 'doc_type' in first_item:
948
+ import importlib
949
+ module_path, class_name = first_item['doc_type'].rsplit('.', 1)
950
+ module = importlib.import_module(module_path)
951
+ doc_class = getattr(module, class_name)
952
+ if not doc_class:
953
+ # Fallback to default Document class from llama_index
954
+ from llama_index.core.schema import Document
955
+ doc_class = Document
956
+
957
+ # Convert each dictionary to document
958
+ for item in items:
959
+ if isinstance(item, dict):
960
+ result.append(doc_class.from_dict(item))
961
+ elif item is None:
962
+ result.append(None)
963
+ elif isinstance(item, list):
964
+ result.append(llamaindex_convert_docs(item))
965
+ else:
966
+ result.append(item)
967
+
968
+ # Convert documents to dictionaries
969
+ else:
970
+ for item in items:
971
+ if hasattr(item, 'to_dict'):
972
+ result.append(item.to_dict())
973
+ elif item is None:
974
+ result.append(None)
975
+ elif isinstance(item, list):
976
+ result.append(llamaindex_convert_docs(item))
977
+ else:
978
+ result.append(item)
979
+
980
+ return result
981
+
982
+ def llamaindex_convert_docs(items):
983
+ """
984
+ Automatically convert between document objects and dictionaries.
985
+
986
+ Args:
987
+ items: A list of document objects or dictionaries
988
+
989
+ Returns:
990
+ List of converted items (dictionaries or document objects)
991
+ """
992
+ result = []
993
+
994
+ # Handle empty or None input
995
+ if not items:
996
+ return result
997
+
998
+ # Determine the type of conversion based on the first item
999
+ if isinstance(items[0], dict):
1000
+ # Get the right document class from the items themselves
1001
+ # Look for a 'doc_type' or '__class__' field in the dictionary
1002
+ doc_class = None
1003
+
1004
+ # Try to get doc class from metadata if available
1005
+ if 'doc_type' in items[0]:
1006
+ import importlib
1007
+ module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
1008
+ module = importlib.import_module(module_path)
1009
+ doc_class = getattr(module, class_name)
1010
+
1011
+ if not doc_class:
1012
+ # Fallback to default Document class from llama_index
1013
+ from llama_index.core.schema import Document
1014
+ doc_class = Document
1015
+
1016
+ # Convert dictionaries to documents
1017
+ for item in items:
1018
+ if isinstance(item, dict):
1019
+ result.append(doc_class.from_dict(item))
1020
+ else:
1021
+ # Convert documents to dictionaries
1022
+ for item in items:
1023
+ if hasattr(item, 'to_dict'):
1024
+ result.append(item.to_dict())
1025
+
1026
+ return result
1027
+ return (llamaindex_convert_docs_multi,)
1028
+
1029
+
1030
+ @app.cell
1031
+ def helper_function_create_df():
1032
+ def create_document_dataframes(dict_from_docs):
1033
+ """
1034
+ Creates a pandas DataFrame for each file in the dictionary.
1035
+
1036
+ Args:
1037
+ dict_from_docs: Dictionary mapping filenames to lists of documents
1038
+
1039
+ Returns:
1040
+ List of pandas DataFrames, each representing all documents from a single file
1041
+ """
1042
+ dataframes = []
1043
+
1044
+ for filename, docs in dict_from_docs.items():
1045
+ # Create a list to hold all document records for this file
1046
+ file_records = []
1047
+
1048
+ for i, doc in enumerate(docs):
1049
+ # Convert the document to a format compatible with DataFrame
1050
+ if hasattr(doc, 'to_dict'):
1051
+ doc_data = doc.to_dict()
1052
+ elif isinstance(doc, dict):
1053
+ doc_data = doc
1054
+ else:
1055
+ doc_data = {'content': str(doc)}
1056
+
1057
+ # Add document index information
1058
+ doc_data['doc_index'] = i
1059
+
1060
+ # Add to the list of records for this file
1061
+ file_records.append(doc_data)
1062
+
1063
+ # Create a single DataFrame for all documents in this file
1064
+ if file_records:
1065
+ df = pd.DataFrame(file_records)
1066
+ df['filename'] = filename # Add filename as a column
1067
+ dataframes.append(df)
1068
+
1069
+ return dataframes
1070
+
1071
+ def create_dataframe_previews(dataframe_list, page_size=5):
1072
+ """
1073
+ Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.
1074
+
1075
+ Args:
1076
+ dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
1077
+ page_size: Number of rows to show per page for each component
1078
+
1079
+ Returns:
1080
+ List of mo.ui.dataframe components
1081
+ """
1082
+ # Create a list of mo.ui.dataframe components
1083
+ preview_components = []
1084
+
1085
+ for df in dataframe_list:
1086
+ # Create a mo.ui.dataframe component for this DataFrame
1087
+ preview = mo.ui.dataframe(df, page_size=page_size)
1088
+ preview_components.append(preview)
1089
+
1090
+ return preview_components
1091
+ return
1092
+
1093
+
1094
+ @app.cell
1095
+ def helper_function_chart_preparation():
1096
+ import altair as alt
1097
+ import numpy as np
1098
+ import plotly.express as px
1099
+ from sklearn.manifold import TSNE
1100
+
1101
+ def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
1102
+ """
1103
+ Prepare embedding data for visualization
1104
+
1105
+ Args:
1106
+ embeddings: List of embeddings arrays
1107
+ texts: List of text strings
1108
+ model_id: Embedding model ID (optional)
1109
+ embedding_dimensions: Embedding dimensions (optional)
1110
+
1111
+ Returns:
1112
+ DataFrame with processed data and metadata
1113
+ """
1114
+ # Flatten embeddings (in case they're nested)
1115
+ flattened_embeddings = []
1116
+ for emb in embeddings:
1117
+ if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
1118
+ flattened_embeddings.append(emb[0]) # Take first element if nested
1119
+ else:
1120
+ flattened_embeddings.append(emb)
1121
+
1122
+ # Convert to numpy array
1123
+ embedding_array = np.array(flattened_embeddings)
1124
+
1125
+ # Apply dimensionality reduction (t-SNE)
1126
+ tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1))
1127
+ reduced_embeddings = tsne.fit_transform(embedding_array)
1128
+
1129
+ # Create truncated texts for display
1130
+ truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]
1131
+
1132
+ # Create dataframe for visualization
1133
+ df = pd.DataFrame({
1134
+ "x": reduced_embeddings[:, 0],
1135
+ "y": reduced_embeddings[:, 1],
1136
+ "text": truncated_texts,
1137
+ "full_text": texts,
1138
+ "index": range(len(texts))
1139
+ })
1140
+
1141
+ # Add metadata
1142
+ metadata = {
1143
+ "model_id": model_id,
1144
+ "embedding_dimensions": embedding_dimensions
1145
+ }
1146
+
1147
+ return df, metadata
1148
+
1149
+ def create_embedding_chart(df, metadata=None):
1150
+ """
1151
+ Create an Altair chart for embedding visualization
1152
+
1153
+ Args:
1154
+ df: DataFrame with x, y coordinates and text
1155
+ metadata: Dictionary with model_id and embedding_dimensions
1156
+
1157
+ Returns:
1158
+ Altair chart
1159
+ """
1160
+ model_id = metadata.get("model_id") if metadata else None
1161
+ embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
1162
+
1163
+ selection = alt.selection_multi(fields=['index'])
1164
+
1165
+ base = alt.Chart(df).encode(
1166
+ x=alt.X("x:Q", title="Dimension 1"),
1167
+ y=alt.Y("y:Q", title="Dimension 2"),
1168
+ tooltip=["text", "index"]
1169
+ )
1170
+
1171
+ points = base.mark_circle(size=100).encode(
1172
+ color=alt.Color("index:N", legend=None),
1173
+ opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
1174
+ ).add_selection(selection) # Add this line to apply the selection
1175
+
1176
+ text = base.mark_text(align="left", dx=7).encode(
1177
+ text="index:N"
1178
+ )
1179
+
1180
+ return (points + text).properties(
1181
+ width=700,
1182
+ height=500,
1183
+ title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
1184
+ ).interactive()
1185
+
1186
+ def show_selected_text(indices, texts):
1187
+ """
1188
+ Create markdown display for selected texts
1189
+
1190
+ Args:
1191
+ indices: List of selected indices
1192
+ texts: List of all texts
1193
+
1194
+ Returns:
1195
+ Markdown string
1196
+ """
1197
+ if not indices:
1198
+ return "No text selected"
1199
+
1200
+ selected_texts = [texts[i] for i in indices if i < len(texts)]
1201
+ return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])
1202
+
1203
+ def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
1204
+ """
1205
+ Prepare embedding data for 3D visualization
1206
+
1207
+ Args:
1208
+ embeddings: List of embeddings arrays
1209
+ texts: List of text strings
1210
+ model_id: Embedding model ID (optional)
1211
+ embedding_dimensions: Embedding dimensions (optional)
1212
+
1213
+ Returns:
1214
+ DataFrame with processed data and metadata
1215
+ """
1216
+ # Flatten embeddings (in case they're nested)
1217
+ flattened_embeddings = []
1218
+ for emb in embeddings:
1219
+ if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
1220
+ flattened_embeddings.append(emb[0])
1221
+ else:
1222
+ flattened_embeddings.append(emb)
1223
+
1224
+ # Convert to numpy array
1225
+ embedding_array = np.array(flattened_embeddings)
1226
+
1227
+ # Handle the case of a single embedding differently
1228
+ if len(embedding_array) == 1:
1229
+ # For a single point, we don't need t-SNE, just use a fixed position
1230
+ reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
1231
+ else:
1232
+ # Apply dimensionality reduction to 3D
1233
+ # Fix: Ensure perplexity is at least 1.0
1234
+ perplexity_value = max(1.0, min(30, len(embedding_array)-1))
1235
+ tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
1236
+ reduced_embeddings = tsne.fit_transform(embedding_array)
1237
+
1238
+ # Format texts for display
1239
+ formatted_texts = []
1240
+ for text in texts:
1241
+ # Truncate if needed
1242
+ if len(text) > 500:
1243
+ text = text[:500] + "..."
1244
+
1245
+ # Insert line breaks for wrapping
1246
+ wrapped_text = ""
1247
+ for i in range(0, len(text), 50):
1248
+ wrapped_text += text[i:i+50] + "<br>"
1249
+
1250
+ formatted_texts.append("<b>"+wrapped_text+"</b>")
1251
+
1252
+ # Create dataframe for visualization
1253
+ df = pd.DataFrame({
1254
+ "x": reduced_embeddings[:, 0],
1255
+ "y": reduced_embeddings[:, 1],
1256
+ "z": reduced_embeddings[:, 2],
1257
+ "text": formatted_texts,
1258
+ "full_text": texts,
1259
+ "index": range(len(texts)),
1260
+ "embedding": flattened_embeddings # Store the original embeddings for later use
1261
+ })
1262
+
1263
+ # Add metadata
1264
+ metadata = {
1265
+ "model_id": model_id,
1266
+ "embedding_dimensions": embedding_dimensions
1267
+ }
1268
+
1269
+ return df, metadata
1270
+
1271
+ def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int=3):
1272
+ """
1273
+ Create a 3D Plotly chart for embedding visualization with proximity-based coloring
1274
+ """
1275
+ model_id = metadata.get("model_id") if metadata else None
1276
+ embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
1277
+
1278
+ # Calculate the proximity between points
1279
+ from scipy.spatial.distance import pdist, squareform
1280
+ # Get the coordinates as a numpy array
1281
+ coords = df[['x', 'y', 'z']].values
1282
+
1283
+ # Calculate pairwise distances
1284
+ dist_matrix = squareform(pdist(coords))
1285
+
1286
+ # For each point, find its average distance to all other points
1287
+ avg_distances = np.mean(dist_matrix, axis=1)
1288
+
1289
+ # Add this to the dataframe - smaller values = closer to other points
1290
+ df['proximity'] = avg_distances
1291
+
1292
+ # Create 3D scatter plot with proximity-based coloring
1293
+ fig = px.scatter_3d(
1294
+ df,
1295
+ x='x',
1296
+ y='y',
1297
+ z='z',
1298
+ # x='petal_length', # Changed from 'x' to 'petal_length'
1299
+ # y='petal_width', # Changed from 'y' to 'petal_width'
1300
+ # z='petal_height',
1301
+ color='proximity', # Color based on proximity
1302
+ color_continuous_scale='Viridis_r', # Reversed so closer points are warmer colors
1303
+ hover_data=['text', 'index', 'proximity'],
1304
+ labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
1305
+ # labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
1306
+ title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
1307
+ text='index',
1308
+ # size_max=marker_size_var
1309
+ )
1310
+
1311
+ # Update marker size and layout
1312
+ # fig.update_traces(marker=dict(size=3), selector=dict(mode='markers'))
1313
+ fig.update_traces(
1314
+ marker=dict(
1315
+ size=marker_size_var, # Very small marker size
1316
+ opacity=0.7, # Slightly transparent
1317
+ symbol="diamond", # Use circle markers (other options: "square", "diamond", "cross", "x")
1318
+ line=dict(
1319
+ width=0.5, # Very thin border
1320
+ color="white" # White outline makes small dots more visible
1321
+ )
1322
+ ),
1323
+ textfont=dict(
1324
+ color="rgba(255, 255, 255, 0.3)",
1325
+ size=8
1326
+ ),
1327
+ # hovertemplate="<b>index=%{text}</b><br>%{customdata[0]}<br><br>Avg Distance=%{customdata[2]:.4f}<extra></extra>", ### Hover Changes
1328
+ hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
1329
+ hoverinfo="text+name",
1330
+ hoverlabel=dict(
1331
+ bgcolor="white", # White background for hover labels
1332
+ font_size=12 # Font size for hover text
1333
+ ),
1334
+ selector=dict(type='scatter3d')
1335
+ )
1336
+
1337
+ # Keep your existing layout settings
1338
+ fig.update_layout(
1339
+ scene=dict(
1340
+ xaxis=dict(
1341
+ title='Dimension 1',
1342
+ nticks=40,
1343
+ backgroundcolor="rgb(10, 10, 20, 0.1)",
1344
+ gridcolor="white",
1345
+ showbackground=True,
1346
+ gridwidth=0.35,
1347
+ zerolinecolor="white",
1348
+ ),
1349
+ yaxis=dict(
1350
+ title='Dimension 2',
1351
+ nticks=40,
1352
+ backgroundcolor="rgb(10, 10, 20, 0.1)",
1353
+ gridcolor="white",
1354
+ showbackground=True,
1355
+ gridwidth=0.35,
1356
+ zerolinecolor="white",
1357
+ ),
1358
+ zaxis=dict(
1359
+ title='Dimension 3',
1360
+ nticks=40,
1361
+ backgroundcolor="rgb(10, 10, 20, 0.1)",
1362
+ gridcolor="white",
1363
+ showbackground=True,
1364
+ gridwidth=0.35,
1365
+ zerolinecolor="white",
1366
+ ),
1367
+ # Control camera view angle
1368
+ camera=dict(
1369
+ up=dict(x=0, y=0, z=1),
1370
+ center=dict(x=0, y=0, z=0),
1371
+ eye=dict(x=1.25, y=1.25, z=1.25),
1372
+ ),
1373
+ aspectratio=dict(x=1, y=1, z=1),
1374
+ aspectmode='data'
1375
+ ),
1376
+ width=int(chart_width),
1377
+ height=int(chart_height),
1378
+ margin=dict(r=20, l=10, b=10, t=50),
1379
+ paper_bgcolor="rgb(0, 0, 0)",
1380
+ plot_bgcolor="rgb(0, 0, 0)",
1381
+ coloraxis_colorbar=dict(
1382
+ title="Average Distance",
1383
+ thicknessmode="pixels", thickness=20,
1384
+ lenmode="pixels", len=400,
1385
+ yanchor="top", y=1,
1386
+ ticks="outside",
1387
+ dtick=0.1
1388
+ )
1389
+ )
1390
+
1391
+ return fig
1392
+ return create_3d_embedding_chart, prepare_embedding_data_3d
1393
+
1394
+
1395
+ @app.cell
1396
+ def helper_function_text_preparation():
1397
+ def convert_table_to_json_docs(df, selected_columns=None):
1398
+ """
1399
+ Convert a pandas DataFrame or dictionary to a list of JSON documents.
1400
+ Dynamically includes columns based on user selection.
1401
+ Column names are standardized to lowercase with underscores instead of spaces
1402
+ and special characters removed.
1403
+
1404
+ Args:
1405
+ df: The DataFrame or dictionary to process
1406
+ selected_columns: List of column names to include in the output documents
1407
+
1408
+ Returns:
1409
+ list: A list of dictionaries, each representing a row as a JSON document
1410
+ """
1411
+ import pandas as pd
1412
+ import re
1413
+
1414
+ def standardize_key(key):
1415
+ """Convert a column name to lowercase with underscores instead of spaces and no special characters"""
1416
+ if not isinstance(key, str):
1417
+ return str(key).lower()
1418
+ # Replace spaces with underscores and convert to lowercase
1419
+ key = key.lower().replace(' ', '_')
1420
+ # Remove special characters (keeping alphanumeric and underscores)
1421
+ return re.sub(r'[^\w]', '', key)
1422
+
1423
+ # Handle case when input is a dictionary
1424
+ if isinstance(df, dict):
1425
+ # Filter the dictionary to include only selected columns
1426
+ if selected_columns:
1427
+ return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
1428
+ else:
1429
+ # If no columns selected, return all key-value pairs with standardized keys
1430
+ return [{standardize_key(k): v for k, v in df.items()}]
1431
+
1432
+ # Handle case when df is None
1433
+ if df is None:
1434
+ return []
1435
+
1436
+ # Ensure df is a DataFrame
1437
+ if not isinstance(df, pd.DataFrame):
1438
+ try:
1439
+ df = pd.DataFrame(df)
1440
+ except:
1441
+ return [] # Return empty list if conversion fails
1442
+
1443
+ # Now check if DataFrame is empty
1444
+ if df.empty:
1445
+ return []
1446
+
1447
+ # If no columns are specifically selected, use all available columns
1448
+ if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
1449
+ selected_columns = list(df.columns)
1450
+
1451
+ # Determine which columns exist in the DataFrame
1452
+ available_columns = []
1453
+ columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
1454
+
1455
+ for col in selected_columns:
1456
+ if col in df.columns:
1457
+ available_columns.append(col)
1458
+ elif isinstance(col, str) and col.lower() in columns_lower:
1459
+ available_columns.append(columns_lower[col.lower()])
1460
+
1461
+ # If no valid columns found, return empty list
1462
+ if not available_columns:
1463
+ return []
1464
+
1465
+ # Process rows
1466
+ json_docs = []
1467
+ for _, row in df.iterrows():
1468
+ doc = {}
1469
+ for col in available_columns:
1470
+ value = row[col]
1471
+ # Standardize the column name when adding to document
1472
+ std_col = standardize_key(col)
1473
+ doc[std_col] = None if pd.isna(value) else value
1474
+ json_docs.append(doc)
1475
+
1476
+ return json_docs
1477
+
1478
+ def get_column_values(df, columns_to_include):
1479
+ """
1480
+ Extract values from specified columns of a dataframe as lists.
1481
+
1482
+ Args:
1483
+ df: A pandas DataFrame
1484
+ columns_to_include: A list of column names to extract
1485
+
1486
+ Returns:
1487
+ Dictionary with column names as keys and their values as lists
1488
+ """
1489
+ result = {}
1490
+
1491
+ # Validate that columns exist in the dataframe
1492
+ valid_columns = [col for col in columns_to_include if col in df.columns]
1493
+ invalid_columns = set(columns_to_include) - set(valid_columns)
1494
+
1495
+ if invalid_columns:
1496
+ print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")
1497
+
1498
+ # Extract values for each valid column
1499
+ for col in valid_columns:
1500
+ result[col] = df[col].tolist()
1501
+
1502
+ return result
1503
+
1504
+ def get_data_in_range(doc_dict_df, index_range, columns_to_include):
1505
+ """
1506
+ Extract values from specified columns of a dataframe within a given index range.
1507
+
1508
+ Args:
1509
+ doc_dict_df: The pandas DataFrame to extract data from
1510
+ index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
1511
+ columns_to_include: A list of column names to extract
1512
+
1513
+ Returns:
1514
+ Dictionary with column names as keys and their values (within the index range) as lists
1515
+ """
1516
+ # Validate the index range
1517
+ max_index = len(doc_dict_df)
1518
+ if index_range <= 0:
1519
+ print(f"Warning: Invalid index range {index_range}. Must be positive.")
1520
+ return {}
1521
+
1522
+ # Adjust index_range if it exceeds the dataframe length
1523
+ if index_range > max_index:
1524
+ print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
1525
+ index_range = max_index
1526
+
1527
+ # Slice the dataframe to get rows from 0 to index_range-1
1528
+ df_subset = doc_dict_df.iloc[:index_range]
1529
+
1530
+ # Use the provided get_column_values function to extract column data
1531
+ return get_column_values(df_subset, columns_to_include)
1532
+
1533
+ def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
1534
+ """
1535
+ Extract values from specified columns of a dataframe within a given index range.
1536
+ Wraps string values with triple quotes and escapes URLs.
1537
+
1538
+ Args:
1539
+ doc_dict_df: The pandas DataFrame to extract data from
1540
+ index_range: A list of two integers specifying the start and end indices of rows to include
1541
+ (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
1542
+ columns_to_include: A list of column names to extract
1543
+ """
1544
+ # Validate the index range
1545
+ start_idx, end_idx = index_range
1546
+ max_index = len(doc_dict_df)
1547
+
1548
+ # Validate start index
1549
+ if start_idx < 0:
1550
+ print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
1551
+ start_idx = 0
1552
+
1553
+ # Validate end index
1554
+ if end_idx <= start_idx:
1555
+ print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
1556
+ end_idx = start_idx + 1
1557
+
1558
+ # Adjust end index if it exceeds the dataframe length
1559
+ if end_idx > max_index:
1560
+ print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
1561
+ end_idx = max_index
1562
+
1563
+ # Slice the dataframe to get rows from start_idx to end_idx-1
1564
+ # Using .loc with slice to preserve original indices
1565
+ df_subset = doc_dict_df.iloc[start_idx:end_idx]
1566
+
1567
+ # Use the provided get_column_values function to extract column data
1568
+ result = get_column_values(df_subset, columns_to_include)
1569
+
1570
+ # Process each string result to wrap in triple quotes
1571
+ for col in result:
1572
+ if isinstance(result[col], list):
1573
+ # Create a new list with items wrapped in triple quotes
1574
+ processed_items = []
1575
+ for item in result[col]:
1576
+ if isinstance(item, str):
1577
+ # Replace http:// and https:// with escaped versions
1578
+ item = item.replace("http://", "http\\://").replace("https://", "https\\://")
1579
+ # processed_items.append('"""' + item + '"""')
1580
+ processed_items.append(item)
1581
+ else:
1582
+ processed_items.append(item)
1583
+ result[col] = processed_items
1584
+ return result
1585
+ return (get_data_in_range_triplequote,)
1586
+
1587
+
1588
+ @app.cell
1589
+ def prepare_doc_select(sentence_splitter_config):
1590
+ def prepare_document_selection(node_dict):
1591
+ """
1592
+ Creates document selection UI component.
1593
+ Args:
1594
+ node_dict: Dictionary mapping filenames to lists of documents
1595
+ Returns:
1596
+ mo.ui component for document selection
1597
+ """
1598
+ # Calculate total number of documents across all files
1599
+ total_docs = sum(len(docs) for docs in node_dict.values())
1600
+
1601
+ # Create a combined DataFrame of all documents for table selection
1602
+ all_docs_records = []
1603
+ doc_index_global = 0
1604
+ for filename, docs in node_dict.items():
1605
+ for i, doc in enumerate(docs):
1606
+ # Convert the document to a format compatible with DataFrame
1607
+ if hasattr(doc, 'to_dict'):
1608
+ doc_data = doc.to_dict()
1609
+ elif isinstance(doc, dict):
1610
+ doc_data = doc
1611
+ else:
1612
+ doc_data = {'content': str(doc)}
1613
+
1614
+ # Add metadata
1615
+ doc_data['filename'] = filename
1616
+ doc_data['doc_index'] = i
1617
+ doc_data['global_index'] = doc_index_global
1618
+ all_docs_records.append(doc_data)
1619
+ doc_index_global += 1
1620
+
1621
+ # Create UI component
1622
+ stop_value = max(total_docs, 2)
1623
+ llama_docs = mo.ui.range_slider(
1624
+ start=1,
1625
+ stop=stop_value,
1626
+ step=1,
1627
+ full_width=True,
1628
+ show_value=True,
1629
+ label="**Select a Range of Chunks to Visualize:**"
1630
+ ).form(submit_button_disabled=check_state(sentence_splitter_config.value))
1631
+
1632
+ return llama_docs
1633
+ return (prepare_document_selection,)
1634
+
1635
+
1636
+ @app.cell
1637
+ def document_range_selection(
1638
+ dict_from_nodes,
1639
+ prepare_document_selection,
1640
+ set_range_slider_state,
1641
+ ):
1642
+ if dict_from_nodes is not None:
1643
+ llama_docs = prepare_document_selection(dict_from_nodes)
1644
+ set_range_slider_state(llama_docs)
1645
+ else:
1646
+ bare_dict = {}
1647
+ llama_docs = prepare_document_selection(bare_dict)
1648
+ return
1649
+
1650
+
1651
+ @app.function
1652
+ def create_cumulative_dataframe(dict_from_docs):
1653
+ """
1654
+ Creates a cumulative DataFrame from a nested dictionary of documents.
1655
+
1656
+ Args:
1657
+ dict_from_docs: Dictionary mapping filenames to lists of documents
1658
+
1659
+ Returns:
1660
+ DataFrame with all documents flattened with global indices
1661
+ """
1662
+ # Create a list to hold all document records
1663
+ all_records = []
1664
+ global_idx = 1 # Start from 1 to match range slider expectations
1665
+
1666
+ for filename, docs in dict_from_docs.items():
1667
+ for i, doc in enumerate(docs):
1668
+ # Convert the document to a dict format
1669
+ if hasattr(doc, 'to_dict'):
1670
+ doc_data = doc.to_dict()
1671
+ elif isinstance(doc, dict):
1672
+ doc_data = doc.copy()
1673
+ else:
1674
+ doc_data = {'content': str(doc)}
1675
+
1676
+ # Add additional metadata
1677
+ doc_data['filename'] = filename
1678
+ doc_data['doc_index'] = i
1679
+ doc_data['global_index'] = global_idx
1680
+
1681
+ # If there's 'content' but no 'text', create a 'text' field
1682
+ if 'content' in doc_data and 'text' not in doc_data:
1683
+ doc_data['text'] = doc_data['content']
1684
+
1685
+ all_records.append(doc_data)
1686
+ global_idx += 1
1687
+
1688
+ # Create DataFrame from all records
1689
+ return pd.DataFrame(all_records)
1690
+
1691
+
1692
+ @app.function
1693
+ def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
1694
+ """
1695
+ Create a list of stat objects for each item in the specified dictionary.
1696
+
1697
+ Parameters:
1698
+ - texts_dict (dict): Dictionary containing the text data
1699
+ - bordered (bool): Whether the stats should be bordered
1700
+ - object_names (list or tuple): Two object names to use for label and value
1701
+ [label_object, value_object]
1702
+ - group_by_row (bool): Whether to group stats in rows (horizontal stacks)
1703
+ - items_per_row (int): Number of stat objects per row when group_by_row is True
1704
+
1705
+ Returns:
1706
+ - object: A vertical stack of stat objects or rows of stat objects
1707
+ """
1708
+ if not object_names or len(object_names) < 2:
1709
+ raise ValueError("You must provide two object names as a list or tuple")
1710
+
1711
+ label_object = object_names[0]
1712
+ value_object = object_names[1]
1713
+
1714
+ # Validate that both objects exist in the dictionary
1715
+ if label_object not in texts_dict:
1716
+ raise ValueError(f"Label object '{label_object}' not found in texts_dict")
1717
+ if value_object not in texts_dict:
1718
+ raise ValueError(f"Value object '{value_object}' not found in texts_dict")
1719
+
1720
+ # Determine how many items to process (based on the label object length)
1721
+ num_items = len(texts_dict[label_object])
1722
+
1723
+ # Create individual stat objects
1724
+ individual_stats = []
1725
+ for i in range(num_items):
1726
+ stat = mo.stat(
1727
+ label=texts_dict[label_object][i],
1728
+ value=f"{label} Number: {len(texts_dict[value_object][i])}",
1729
+ bordered=bordered
1730
+ )
1731
+ individual_stats.append(stat)
1732
+
1733
+ # If grouping is not enabled, just return a vertical stack of all stats
1734
+ if not group_by_row:
1735
+ return mo.vstack(individual_stats, wrap=False)
1736
+
1737
+ # Group stats into rows based on items_per_row
1738
+ rows = []
1739
+ for i in range(0, num_items, items_per_row):
1740
+ # Get a slice of stats for this row (up to items_per_row items)
1741
+ row_stats = individual_stats[i:i+items_per_row]
1742
+ # Create a horizontal stack for this row
1743
+ widths = [0.35] * len(row_stats)
1744
+ row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths)
1745
+ rows.append(row)
1746
+
1747
+ # Return a vertical stack of all rows
1748
+ return mo.vstack(rows)
1749
+
1750
+
1751
+ @app.cell
1752
+ def prepare_chart_embeddings(
1753
+ chunks_to_process,
1754
+ emb_model,
1755
+ emb_model_emb_dim,
1756
+ get_embedding_state,
1757
+ prepare_embedding_data_3d,
1758
+ ):
1759
+ # chart_dataframe, chart_metadata = None, None
1760
+ if chunks_to_process is not None and get_embedding_state() is not None:
1761
+ chart_dataframe, chart_metadata = prepare_embedding_data_3d(
1762
+ get_embedding_state(),
1763
+ chunks_to_process,
1764
+ model_id=emb_model,
1765
+ embedding_dimensions=emb_model_emb_dim
1766
+ )
1767
+ else:
1768
+ chart_dataframe, chart_metadata = None, None
1769
+ return chart_dataframe, chart_metadata
1770
+
1771
+
1772
+ @app.cell
1773
+ def chart_dims():
1774
+ chart_dimensions = (
1775
+ mo.md('''
1776
+ > **Adjust Chart Window**
1777
+
1778
+ {chart_height}
1779
+
1780
+ {chat_width}
1781
+
1782
+ ''').batch(
1783
+ chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
1784
+ chat_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
1785
+ )
1786
+ )
1787
+ return (chart_dimensions,)
1788
+
1789
+
1790
+ @app.cell
1791
+ def chart_dim_values(chart_dimensions):
1792
+ chart_height = chart_dimensions.value['chart_height']
1793
+ chart_width = chart_dimensions.value['chat_width']
1794
+ return chart_height, chart_width
1795
+
1796
+
1797
+ @app.cell
1798
+ def create_baseline_chart(
1799
+ chart_dataframe,
1800
+ chart_height,
1801
+ chart_metadata,
1802
+ chart_width,
1803
+ create_3d_embedding_chart,
1804
+ ):
1805
+ if chart_dataframe is not None and chart_metadata is not None:
1806
+ emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
1807
+ chart = mo.ui.plotly(emb_plot)
1808
+ else:
1809
+ emb_plot = None
1810
+ chart = None
1811
+ return (emb_plot,)
1812
+
1813
+
1814
+ @app.cell
1815
+ def test_query(get_chunk_state):
1816
+ placeholder = """How can i use watsonx.data to perform vector search?"""
1817
+
1818
+ query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(show_clear_button=True, submit_button_disabled=check_state(get_chunk_state()))
1819
+ return (query,)
1820
+
1821
+
1822
+ @app.cell
1823
+ def query_stack(chart_dimensions, query):
1824
+ # query_stack = mo.hstack([query], justify="space-around", align="center", widths=[0.65])
1825
+ query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
1826
+ return (query_stack,)
1827
+
1828
+
1829
+ @app.function
1830
+ def check_state(variable):
1831
+ return variable is None
1832
+
1833
+
1834
+ @app.cell
1835
+ def helper_function_add_query_to_chart():
1836
+ def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
1837
+ """
1838
+ Add a query point to an existing 3D embedding chart as a large red dot.
1839
+
1840
+ Args:
1841
+ existing_chart: The existing plotly figure or chart data
1842
+ query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
1843
+ query_text: Text of the query to display on hover
1844
+ marker_size: Size of the query marker (default: 18, typically 2x other markers)
1845
+
1846
+ Returns:
1847
+ A modified plotly figure with the query point added as a red dot
1848
+ """
1849
+ import plotly.graph_objects as go
1850
+
1851
+ # Create a deep copy of the existing chart to avoid modifying the original
1852
+ import copy
1853
+ chart_copy = copy.deepcopy(existing_chart)
1854
+
1855
+ # Handle case where chart_copy is a dictionary or list (from mo.ui.plotly)
1856
+ if isinstance(chart_copy, (dict, list)):
1857
+ # Create a new plotly figure from the data
1858
+ import plotly.graph_objects as go
1859
+
1860
+ if isinstance(chart_copy, list):
1861
+ # If it's a list, assume it's a list of traces
1862
+ fig = go.Figure(data=chart_copy)
1863
+ else:
1864
+ # If it's a dict with 'data' and 'layout'
1865
+ fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))
1866
+
1867
+ chart_copy = fig
1868
+
1869
+ # Create the query trace
1870
+ query_trace = go.Scatter3d(
1871
+ x=[query_coords['x']],
1872
+ y=[query_coords['y']],
1873
+ z=[query_coords['z']],
1874
+ mode='markers',
1875
+ name='Query',
1876
+ marker=dict(
1877
+ size=marker_size, # Typically 2x the size of other markers
1878
+ color='red', # Bright red color
1879
+ symbol='circle', # Circle shape
1880
+ opacity=0.70, # Fully opaque
1881
+ line=dict(
1882
+ width=1, # Thin white border
1883
+ color='white'
1884
+ )
1885
+ ),
1886
+ # text=['Query: ' + query_text],
1887
+ text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])], ### Text Wrapping
1888
+ hoverinfo="text+name"
1889
+ )
1890
+
1891
+ # Add the query trace to the chart copy
1892
+ chart_copy.add_trace(query_trace)
1893
+
1894
+ return chart_copy
1895
+
1896
+
1897
+ def get_query_coordinates(reference_embeddings=None, query_embedding=None):
1898
+ """
1899
+ Calculate appropriate coordinates for a query point based on reference embeddings.
1900
+
1901
+ This function handles several scenarios:
1902
+ 1. If both reference embeddings and query embedding are provided, it places the
1903
+ query near similar documents.
1904
+ 2. If only reference embeddings are provided, it places the query at a visible
1905
+ location near the center of the chart.
1906
+ 3. If neither are provided, it returns default origin coordinates.
1907
+
1908
+ Args:
1909
+ reference_embeddings: DataFrame with x, y, z coordinates from the main chart
1910
+ query_embedding: The embedding vector of the query
1911
+
1912
+ Returns:
1913
+ Dictionary with x, y, z coordinates for the query point
1914
+ """
1915
+ import numpy as np
1916
+
1917
+ # Default coordinates (origin with slight offset)
1918
+ default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}
1919
+
1920
+ # If we don't have reference embeddings, return default
1921
+ if reference_embeddings is None or len(reference_embeddings) == 0:
1922
+ return default_coords
1923
+
1924
+ # If we have reference embeddings but no query embedding,
1925
+ # position at a visible location near the center
1926
+ if query_embedding is None:
1927
+ center_coords = {
1928
+ 'x': reference_embeddings['x'].mean(),
1929
+ 'y': reference_embeddings['y'].mean(),
1930
+ 'z': reference_embeddings['z'].mean()
1931
+ }
1932
+ return center_coords
1933
+
1934
+ # If we have both reference embeddings and query embedding,
1935
+ # try to position near similar documents
1936
+ try:
1937
+ from sklearn.metrics.pairwise import cosine_similarity
1938
+
1939
+ # Check if original embeddings are in the dataframe
1940
+ if 'embedding' in reference_embeddings.columns:
1941
+ # Get all document embeddings as a 2D array
1942
+ if isinstance(reference_embeddings['embedding'].iloc[0], list):
1943
+ doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
1944
+ else:
1945
+ doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])
1946
+
1947
+ # Reshape query embedding for comparison
1948
+ query_emb_array = np.array(query_embedding)
1949
+ if query_emb_array.ndim == 1:
1950
+ query_emb_array = query_emb_array.reshape(1, -1)
1951
+
1952
+ # Calculate cosine similarities
1953
+ similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]
1954
+
1955
+ # Find the closest document
1956
+ closest_idx = np.argmax(similarities)
1957
+
1958
+ # Use the position of the closest document, with slight offset for visibility
1959
+ query_coords = {
1960
+ 'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
1961
+ 'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
1962
+ 'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
1963
+ }
1964
+ return query_coords
1965
+ except Exception as e:
1966
+ print(f"Error positioning query near similar documents: {e}")
1967
+
1968
+ # Fallback to center position if similarity calculation fails
1969
+ center_coords = {
1970
+ 'x': reference_embeddings['x'].mean(),
1971
+ 'y': reference_embeddings['y'].mean(),
1972
+ 'z': reference_embeddings['z'].mean()
1973
+ }
1974
+ return center_coords
1975
+ return add_query_to_embedding_chart, get_query_coordinates
1976
+
1977
+
1978
+ @app.cell
1979
+ def combined_chart_visualization(
1980
+ add_query_to_embedding_chart,
1981
+ chart_dataframe,
1982
+ emb_plot,
1983
+ embedding,
1984
+ get_query_coordinates,
1985
+ get_query_state,
1986
+ query,
1987
+ set_chart_state,
1988
+ set_query_state,
1989
+ ):
1990
+ # Usage with highlight_closest=True
1991
+ if chart_dataframe is not None and query.value:
1992
+ # Get the query embedding
1993
+ query_emb = embedding.embed_documents([query.value])
1994
+ set_query_state(query_emb)
1995
+
1996
+ # Get appropriate coordinates for the query
1997
+ query_coords = get_query_coordinates(
1998
+ reference_embeddings=chart_dataframe,
1999
+ query_embedding=get_query_state()
2000
+ )
2001
+
2002
+ # Add the query to the chart with closest points highlighted
2003
+ result = add_query_to_embedding_chart(
2004
+ existing_chart=emb_plot,
2005
+ query_coords=query_coords,
2006
+ query_text=query.value,
2007
+ )
2008
+
2009
+ chart_with_query = result
2010
+
2011
+ # Create the visualization
2012
+ combined_viz = mo.ui.plotly(chart_with_query)
2013
+ set_chart_state(combined_viz)
2014
+ else:
2015
+ combined_viz = None
2016
+ return
2017
+
2018
+
2019
+ @app.cell
2020
+ def _():
2021
+ get_range_slider_state, set_range_slider_state = mo.state(None)
2022
+ return get_range_slider_state, set_range_slider_state
2023
+
2024
+
2025
+ @app.cell
2026
+ def _(get_range_slider_state):
2027
+ if get_range_slider_state() is not None:
2028
+ document_range_stack = get_range_slider_state()
2029
+ else:
2030
+ document_range_stack = None
2031
+ return (document_range_stack,)
2032
+
2033
+
2034
+ @app.cell
2035
+ def _():
2036
+ get_chart_state, set_chart_state = mo.state(None)
2037
+ return get_chart_state, set_chart_state
2038
+
2039
+
2040
+ @app.cell
2041
+ def _(get_chart_state, query):
2042
+ if query.value is not None:
2043
+ chart_visualization = get_chart_state()
2044
+ else:
2045
+ chart_visualization = None
2046
+ return (chart_visualization,)
2047
+
2048
+
2049
+ @app.cell
2050
+ def c(document_range_stack):
2051
+ chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
2052
+ return (chart_range_selection,)
2053
+
2054
+
2055
+ if __name__ == "__main__":
2056
+ app.run()