MilanM committed on
Commit e83e4ae · verified · 1 Parent(s): 58266a3

Update app.py

Files changed (1)
  1. app.py +1899 -329
app.py CHANGED
@@ -1,469 +1,2039 @@
1
  import marimo
2
 
3
- __generated_with = "0.9.2"
4
- app = marimo.App()
5
 
6
 
7
  @app.cell
8
- def __():
9
- import marimo as mo
10
 
11
- mo.md("# Welcome to marimo! 🌊🍃")
12
- return (mo,)
13
 
14
 
15
  @app.cell
16
- def __(mo):
17
- slider = mo.ui.slider(1, 22)
18
- return (slider,)
19
 
20
 
21
  @app.cell
22
- def __(mo, slider):
23
  mo.md(
24
- f"""
25
- marimo is a **reactive** Python notebook.
26
 
27
- This means that unlike traditional notebooks, marimo notebooks **run
28
- automatically** when you modify them or
29
- interact with UI elements, like this slider: {slider}.
30
 
31
- {"##" + "🍃" * slider.value}
32
- """
33
- )
34
- return
35
 
36
 
37
- @app.cell(hide_code=True)
38
- def __(mo):
39
- mo.accordion(
40
- {
41
- "Tip: disabling automatic execution": mo.md(
42
- rf"""
43
- marimo lets you disable automatic execution: just go into the
44
- notebook settings and set
45
 
46
- "Runtime > On Cell Change" to "lazy".
 
47
 
48
- When the runtime is lazy, after running a cell, marimo marks its
49
- descendants as stale instead of automatically running them. The
50
- lazy runtime puts you in control over when cells are run, while
51
- still giving guarantees about the notebook state.
52
- """
53
- )
54
- }
55
  )
56
  return
57
 
58
 
59
- @app.cell(hide_code=True)
60
- def __(mo):
61
- mo.md(
62
- """
63
- Tip: This is a tutorial notebook. You can create your own notebooks
64
- by entering `marimo edit` at the command line.
65
- """
66
- ).callout()
67
  return
68
 
69
 
70
- @app.cell(hide_code=True)
71
- def __(mo):
72
- mo.md(
73
- """
74
- ## 1. Reactive execution
75
 
76
- A marimo notebook is made up of small blocks of Python code called
77
- cells.
78
 
79
- marimo reads your cells and models the dependencies among them: whenever
80
- a cell that defines a global variable is run, marimo
81
- **automatically runs** all cells that reference that variable.
82
 
83
- Reactivity keeps your program state and outputs in sync with your code,
84
- making for a dynamic programming environment that prevents bugs before they
85
- happen.
86
- """
87
  )
88
  return
89
 
90
 
91
- @app.cell(hide_code=True)
92
- def __(changed, mo):
93
- (
94
- mo.md(
95
- f"""
96
- **✨ Nice!** The value of `changed` is now {changed}.
97
 
98
- When you updated the value of the variable `changed`, marimo
99
- **reacted** by running this cell automatically, because this cell
100
- references the global variable `changed`.
101
 
102
- Reactivity ensures that your notebook state is always
103
- consistent, which is crucial for doing good science; it's also what
104
- enables marimo notebooks to double as tools and apps.
105
- """
106
- )
107
- if changed
108
- else mo.md(
109
- """
110
- **🌊 See it in action.** In the next cell, change the value of the
111
- variable `changed` to `True`, then click the run button.
112
- """
113
- )
114
  )
115
  return
116
 
117
 
118
  @app.cell
119
- def __():
120
- changed = False
121
- return (changed,)
122
 
123
 
124
- @app.cell(hide_code=True)
125
- def __(mo):
126
- mo.accordion(
127
  {
128
- "Tip: execution order": (
129
- """
130
- The order of cells on the page has no bearing on
131
- the order in which cells are executed: marimo knows that a cell
132
- reading a variable must run after the cell that defines it. This
133
- frees you to organize your code in the way that makes the most
134
- sense for you.
135
- """
136
- )
137
  }
138
  )
 
139
  return
140
 
141
 
142
- @app.cell(hide_code=True)
143
- def __(mo):
144
- mo.md(
145
- """
146
- **Global names must be unique.** To enable reactivity, marimo imposes a
147
- constraint on how names appear in cells: no two cells may define the same
148
- variable.
149
- """
150
- )
151
  return
152
 
153
 
154
- @app.cell(hide_code=True)
155
- def __(mo):
156
- mo.accordion(
157
  {
158
- "Tip: encapsulation": (
159
- """
160
- By encapsulating logic in functions, classes, or Python modules,
161
- you can minimize the number of global variables in your notebook.
162
- """
163
- )
164
  }
165
  )
 
166
  return
167
 
168
 
169
- @app.cell(hide_code=True)
170
- def __(mo):
171
- mo.accordion(
172
- {
173
- "Tip: private variables": (
174
- """
175
- Variables prefixed with an underscore are "private" to a cell, so
176
- they can be defined by multiple cells.
177
- """
178
- )
179
- }
180
  )
181
  return
182
 
183
 
184
- @app.cell(hide_code=True)
185
- def __(mo):
186
- mo.md(
187
- """
188
- ## 2. UI elements
189
 
190
- Cells can output interactive UI elements. Interacting with a UI
191
- element **automatically triggers notebook execution**: when
192
- you interact with a UI element, its value is sent back to Python, and
193
- every cell that references that element is re-run.
194
 
195
- marimo provides a library of UI elements to choose from under
196
- `marimo.ui`.
197
  """
198
  )
199
- return
200
 
201
 
202
  @app.cell
203
- def __(mo):
204
- mo.md("""**🌊 Some UI elements.** Try interacting with the below elements.""")
205
- return
206
 
207
 
208
  @app.cell
209
- def __(mo):
210
- icon = mo.ui.dropdown(["🍃", "🌊", "✨"], value="🍃")
211
- return (icon,)
212
 
213
 
214
  @app.cell
215
- def __(icon, mo):
216
- repetitions = mo.ui.slider(1, 16, label=f"number of {icon.value}: ")
217
- return (repetitions,)
218
 
219
 
220
  @app.cell
221
- def __(icon, repetitions):
222
- icon, repetitions
223
- return
224
 
225
 
226
  @app.cell
227
- def __(icon, mo, repetitions):
228
- mo.md("# " + icon.value * repetitions.value)
229
- return
230
 
231
 
232
- @app.cell(hide_code=True)
233
- def __(mo):
234
- mo.md(
235
  """
236
- ## 3. marimo is just Python
237
 
238
- marimo cells parse Python (and only Python), and marimo notebooks are
239
- stored as pure Python files — outputs are _not_ included. There's no
240
- magical syntax.
241
 
242
- The Python files generated by marimo are:
243
 
244
- - easily versioned with git, yielding minimal diffs
245
- - legible for both humans and machines
246
- - formattable using your tool of choice,
247
- - usable as Python scripts, with UI elements taking their default
248
- values, and
249
- - importable by other modules (more on that in the future).
250
  """
251
- )
252
- return
253
 
254
 
255
- @app.cell(hide_code=True)
256
- def __(mo):
257
- mo.md(
258
  """
259
- ## 4. Running notebooks as apps
260
 
261
- marimo notebooks can double as apps. Click the app window icon in the
262
- bottom-right to see this notebook in "app view."
263
 
264
- Serve a notebook as an app with `marimo run` at the command-line.
265
- Of course, you can use marimo just to level-up your
266
- notebooking, without ever making apps.
267
  """
268
- )
269
  return
270
 
271
 
272
- @app.cell(hide_code=True)
273
- def __(mo):
274
- mo.md(
275
  """
276
- ## 5. The `marimo` command-line tool
277
 
278
- **Creating and editing notebooks.** Use
279
 
280
- ```
281
- marimo edit
282
- ```
283
 
284
- in a terminal to start the marimo notebook server. From here
285
- you can create a new notebook or edit existing ones.
286
 
287
 
288
- **Running as apps.** Use
289
 
290
- ```
291
- marimo run notebook.py
292
- ```
293
 
294
- to start a webserver that serves your notebook as an app in read-only mode,
295
- with code cells hidden.
296
 
297
- **Convert a Jupyter notebook.** Convert a Jupyter notebook to a marimo
298
- notebook using `marimo convert`:
299
 
300
- ```
301
- marimo convert your_notebook.ipynb > your_app.py
302
- ```
 
303
 
304
- **Tutorials.** marimo comes packaged with tutorials:
305
 
306
- - `dataflow`: more on marimo's automatic execution
307
- - `ui`: how to use UI elements
308
- - `markdown`: how to write markdown, with interpolated values and
309
- LaTeX
310
- - `plots`: how plotting works in marimo
311
- - `sql`: how to use SQL
312
- - `layout`: layout elements in marimo
313
- - `fileformat`: how marimo's file format works
314
- - `markdown-format`: for using `.md` files in marimo
315
- - `for-jupyter-users`: if you are coming from Jupyter
316
 
317
- Start a tutorial with `marimo tutorial`; for example,
318
 
319
- ```
320
- marimo tutorial dataflow
321
- ```
322
 
323
- In addition to tutorials, we have examples in
324
- [our GitHub repo](https://www.github.com/marimo-team/marimo/tree/main/examples).
325
  """
326
- )
327
- return
328
 
329
 
330
- @app.cell(hide_code=True)
331
- def __(mo):
332
- mo.md(
333
  """
334
- ## 6. The marimo editor
335
 
336
- Here are some tips to help you get started with the marimo editor.
337
  """
338
- )
339
- return
340
 
341
 
342
  @app.cell
343
- def __(mo, tips):
344
- mo.accordion(tips)
345
- return
346
 
347
 
348
- @app.cell(hide_code=True)
349
- def __(mo):
350
- mo.md("""## Finally, a fun fact""")
351
  return
352
 
353
 
354
- @app.cell(hide_code=True)
355
- def __(mo):
356
- mo.md(
357
  """
358
- The name "marimo" is a reference to a type of algae that, under
359
- the right conditions, clumps together to form a small sphere
360
- called a "marimo moss ball". Made of just strands of algae, these
361
- beloved assemblages are greater than the sum of their parts.
362
  """
363
- )
364
  return
365
 
366
 
367
- @app.cell(hide_code=True)
368
- def __():
369
- tips = {
370
- "Saving": (
371
- """
372
- **Saving**
373
 
374
- - _Name_ your app using the box at the top of the screen, or
375
- with `Ctrl/Cmd+s`. You can also create a named app at the
376
- command line, e.g., `marimo edit app_name.py`.
377
 
378
- - _Save_ by clicking the save icon on the bottom right, or by
379
- inputting `Ctrl/Cmd+s`. By default marimo is configured
380
- to autosave.
381
- """
382
- ),
383
- "Running": (
384
- """
385
- 1. _Run a cell_ by clicking the play ( ▷ ) button on the top
386
- right of a cell, or by inputting `Ctrl/Cmd+Enter`.
387
 
388
- 2. _Run a stale cell_ by clicking the yellow run button on the
389
- right of the cell, or by inputting `Ctrl/Cmd+Enter`. A cell is
390
- stale when its code has been modified but not run.
391
 
392
- 3. _Run all stale cells_ by clicking the play ( ▷ ) button on
393
- the bottom right of the screen, or input `Ctrl/Cmd+Shift+r`.
394
- """
395
- ),
396
- "Console Output": (
397
- """
398
- Console output (e.g., `print()` statements) is shown below a
399
- cell.
400
- """
401
- ),
402
- "Creating, Moving, and Deleting Cells": (
403
- """
404
- 1. _Create_ a new cell above or below a given one by clicking
405
- the plus button to the left of the cell, which appears on
406
- mouse hover.
407
 
408
- 2. _Move_ a cell up or down by dragging on the handle to the
409
- right of the cell, which appears on mouse hover.
410
 
411
- 3. _Delete_ a cell by clicking the trash bin icon. Bring it
412
- back by clicking the undo button on the bottom right of the
413
- screen, or with `Ctrl/Cmd+Shift+z`.
414
- """
415
- ),
416
- "Disabling Automatic Execution": (
417
- """
418
- Via the notebook settings (gear icon) or footer panel, you
419
- can disable automatic execution. This is helpful when
420
- working with expensive notebooks or notebooks that have
421
- side-effects like database transactions.
422
- """
423
- ),
424
- "Disabling Cells": (
425
- """
426
- You can disable a cell via the cell context menu.
427
- marimo will never run a disabled cell or any cells that depend on it.
428
- This can help prevent accidental execution of expensive computations
429
- when editing a notebook.
430
- """
431
- ),
432
- "Code Folding": (
433
- """
434
- You can collapse or fold the code in a cell by clicking the arrow
435
- icons in the line number column to the left, or by using keyboard
436
- shortcuts.
437
 
438
- Use the command palette (`Ctrl/Cmd+k`) or a keyboard shortcut to
439
- quickly fold or unfold all cells.
440
- """
441
- ),
442
- "Code Formatting": (
443
- """
444
- If you have [ruff](https://github.com/astral-sh/ruff) installed,
445
- you can format a cell with the keyboard shortcut `Ctrl/Cmd+b`.
446
- """
447
- ),
448
- "Command Palette": (
449
- """
450
- Use `Ctrl/Cmd+k` to open the command palette.
451
- """
452
- ),
453
- "Keyboard Shortcuts": (
454
- """
455
- Open the notebook menu (top-right) or input `Ctrl/Cmd+Shift+h` to
456
- view a list of all keyboard shortcuts.
457
- """
458
- ),
459
- "Configuration": (
460
- """
461
- Configure the editor by clicking the gears icon near the top-right
462
- of the screen.
463
- """
464
- ),
465
- }
466
- return (tips,)
467
 
468
 
469
  if __name__ == "__main__":
 
1
  import marimo
2
 
3
+ __generated_with = "0.13.0"
4
+ app = marimo.App(width="full")
5
 
6
+ with app.setup:
7
+ # Initialization code that runs before all other cells
8
+ import marimo as mo
9
+ from typing import Dict, Optional, List, Union, Any
10
+ from ibm_watsonx_ai import APIClient, Credentials
11
+ from pathlib import Path
12
+ import pandas as pd
13
+ import mimetypes
14
+ import requests
15
+ import zipfile
16
+ import tempfile
17
+ import base64
18
+ import polars
19
+ import time
20
+ import json
21
+ import ast
22
+ import os
23
+ import io
24
+ import re
25
+
26
+ def get_iam_token(api_key):
27
+ return requests.post(
28
+ 'https://iam.cloud.ibm.com/identity/token',
29
+ headers={'Content-Type': 'application/x-www-form-urlencoded'},
30
+ data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key}
31
+ ).json()['access_token']
32
+
33
+ def setup_task_credentials(client):
34
+ # Get existing task credentials
35
+ existing_credentials = client.task_credentials.get_details()
36
+
37
+ # Delete existing credentials if any
38
+ if "resources" in existing_credentials and existing_credentials["resources"]:
39
+ for cred in existing_credentials["resources"]:
40
+ cred_id = client.task_credentials.get_id(cred)
41
+ client.task_credentials.delete(cred_id)
42
+
43
+ # Store new credentials
44
+ return client.task_credentials.store()
45
+
46
+ def get_cred_value(key, creds_var_name="baked_in_creds", default=""): ### Helper for working with preset credentials
47
+ """
48
+ Helper function to safely get a value from a credentials dictionary.
49
+
50
+ Args:
51
+ key: The key to look up in the credentials dictionary.
52
+ creds_var_name: The variable name of the credentials dictionary.
53
+ default: The default value to return if the key is not found.
54
+
55
+ Returns:
56
+ The value from the credentials dictionary if it exists and contains the key,
57
+ otherwise returns the default value.
58
+ """
59
+ # Check if the credentials variable exists in globals
60
+ if creds_var_name in globals():
61
+ creds_dict = globals()[creds_var_name]
62
+ if isinstance(creds_dict, dict) and key in creds_dict:
63
+ return creds_dict[key]
64
+ return default
65
 
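Below, a minimal sketch (not part of this commit) of how the three setup helpers compose; the Bearer-header pattern is an assumption about downstream use, and the key value is a placeholder:

```python
api_key = get_cred_value("api_key")    # "" if no baked_in_creds dict is defined
token = get_iam_token(api_key)         # exchange the IBM Cloud API key for an IAM access token
headers = {"Authorization": f"Bearer {token}"}  # assumed bearer-auth pattern for direct REST calls
```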
66
  @app.cell
67
+ def client_variables(client_instantiation_form):
68
+ if client_instantiation_form.value:
69
+ client_setup = client_instantiation_form.value
70
+ else:
71
+ client_setup = None
72
+
73
+ ### Extract Credential Variables:
74
+ if client_setup is not None:
75
+ wx_url = client_setup["wx_region"]
76
+ wx_api_key = client_setup["wx_api_key"].strip()
77
+ os.environ["WATSONX_APIKEY"] = wx_api_key
78
+
79
+ if client_setup["project_id"] is not None:
80
+ project_id = client_setup["project_id"].strip()
81
+ else:
82
+ project_id = None
83
+
84
+ if client_setup["space_id"] is not None:
85
+ space_id = client_setup["space_id"].strip()
86
+ else:
87
+ space_id = None
88
+
89
+ else:
90
+ os.environ["WATSONX_APIKEY"] = ""
91
+ project_id = None
92
+ space_id = None
93
+ wx_api_key = None
94
+ wx_url = None
95
+ return client_setup, project_id, space_id, wx_api_key, wx_url
96
 
97
 
98
+ @app.cell
99
+ def _(client_setup, wx_api_key):
100
+ if client_setup:
101
+ token = get_iam_token(wx_api_key)
102
+ else:
103
+ token = None
104
+ return
105
 
106
  @app.cell
107
+ def _():
108
+ baked_in_creds = {
109
+ "purpose": "",
110
+ "api_key": "",
111
+ "project_id": "",
112
+ "space_id": "",
113
+ }
114
+ return baked_in_creds
115
 
116
 
117
  @app.cell
118
+ def client_instantiation(
119
+ client_setup,
120
+ project_id,
121
+ space_id,
122
+ wx_api_key,
123
+ wx_url,
124
+ ):
125
+ ### Instantiate the watsonx.ai client
126
+ if client_setup:
127
+ wx_credentials = Credentials(
128
+ url=wx_url,
129
+ api_key=wx_api_key
130
+ )
131
+
132
+ if project_id:
133
+ project_client = APIClient(credentials=wx_credentials, project_id=project_id)
134
+ else:
135
+ project_client = None
136
+
137
+ if space_id:
138
+ deployment_client = APIClient(credentials=wx_credentials, space_id=space_id)
139
+ else:
140
+ deployment_client = None
141
+
142
+ if project_client is not None:
143
+ task_credentials_details = setup_task_credentials(project_client)
144
+ else:
145
+ task_credentials_details = setup_task_credentials(deployment_client)
146
+ else:
147
+ wx_credentials = None
148
+ project_client = None
149
+ deployment_client = None
150
+ task_credentials_details = None
151
+
152
+ client_status = mo.md("### Client Instantiation Status will turn Green When Ready")
153
+
154
+ if project_client is not None or deployment_client is not None:
155
+ client_callout_kind = "success"
156
+ else:
157
+ client_callout_kind = "neutral"
158
+ return (
159
+ client_callout_kind,
160
+ client_status,
161
+ deployment_client,
162
+ project_client,
163
+ )
164
+
165
+
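For orientation, a condensed sketch of the instantiation pattern the cell above implements; the URL matches the "US" region from the form, and the IDs are placeholders:

```python
creds = Credentials(url="https://us-south.ml.cloud.ibm.com", api_key="<api_key>")
project_client = APIClient(credentials=creds, project_id="<project_id>")   # project-scoped client
deployment_client = APIClient(credentials=creds, space_id="<space_id>")    # deployment-space client
```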
166
+ @app.cell
167
+ def _():
168
  mo.md(
169
+ r"""
170
+ # watsonx.ai Embedding Visualizer - Marimo Notebook
171
 
172
+ #### This marimo notebook helps build a more intuitive understanding of how vector embeddings work by rendering a 3D visualization of embeddings computed from chunked PDF document pages.
 
 
173
 
174
+ #### It can also help identify gaps in model choice, chunking strategy, or collection contents by showing how far retrieved chunks are from what you are looking for.
175
+ <br>
 
 
176
 
177
+ /// admonition
178
+ Created by ***Milan Mrdenovic*** [[email protected]] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025*
179
+ ///
180
 
 
 
 
 
 
 
 
 
181
 
182
+ >Licensed under Apache 2.0; users hold full accountability for any use or modification of the code.
183
+ ><br>This asset is part of a set meant to help IBMers, IBM Partners, and clients develop a better understanding of how to utilize various watsonx features and generative AI.
184
 
185
+ <br>
186
+ """
 
 
 
 
 
187
  )
188
  return
189
 
190
 
191
+ @app.cell
192
+ def _():
193
+ mo.md("""###Part 1 - Client Setup, File Preparation and Chunking""")
 
 
 
 
 
194
  return
195
 
196
 
197
+ @app.cell
198
+ def accordion_client_setup(client_selector, client_stack):
199
+ ui_accordion_part_1_1 = mo.accordion(
200
+ {
201
+ "Instantiate Client": mo.vstack([client_stack, client_selector], align="center"),
202
+ }
203
+ )
204
 
205
+ ui_accordion_part_1_1
206
+ return
207
 
 
 
 
208
 
209
+ @app.cell
210
+ def accordion_file_upload(select_stack):
211
+ ui_accordion_part_1_2 = mo.accordion(
212
+ {
213
+ "Select Model & Upload Files": select_stack
214
+ }
215
  )
216
+
217
+ ui_accordion_part_1_2
218
  return
219
 
220
 
221
+ @app.cell
222
+ def loaded_texts(
223
+ create_temp_files_from_uploads,
224
+ file_loader,
225
+ pdf_reader,
226
+ run_upload_button,
227
+ set_text_state,
228
+ ):
229
+ if file_loader.value is not None and run_upload_button.value:
230
+ filepaths = create_temp_files_from_uploads(file_loader.value)
231
+ loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True)
232
+
233
+ set_text_state(loaded_texts)
234
+ else:
235
+ filepaths = None
236
+ loaded_texts = None
237
+ return
238
 
 
 
 
239
 
240
+ @app.cell
241
+ def accordion_chunker_setup(chunker_setup):
242
+ ui_accordion_part_1_3 = mo.accordion(
243
+ {
244
+ "Chunker Setup": chunker_setup
245
+ }
 
 
 
 
 
 
246
  )
247
+
248
+ ui_accordion_part_1_3
249
  return
250
 
251
 
252
  @app.cell
253
+ def chunk_documents_to_nodes(
254
+ get_text_state,
255
+ sentence_splitter,
256
+ sentence_splitter_config,
257
+ set_chunk_state,
258
+ ):
259
+ if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
260
+ chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
261
+ set_chunk_state(chunked_texts)
262
+ else:
263
+ chunked_texts = None
264
+ return (chunked_texts,)
265
+
266
+
267
+ @app.cell
268
+ def _():
269
+ mo.md(r"""###Part 2 - Query Setup and Visualization""")
270
+ return
271
 
272
 
273
+ @app.cell
274
+ def accordion_chunk_range(chart_range_selection):
275
+ ui_accordion_part_2_1 = mo.accordion(
276
  {
277
+ "Chunk Range Selection": chart_range_selection
 
 
 
 
 
 
 
 
278
  }
279
  )
280
+ ui_accordion_part_2_1
281
  return
282
 
283
 
284
+ @app.cell
285
+ def chunk_embedding(
286
+ chunks_to_process,
287
+ embedding,
288
+ sentence_splitter_config,
289
+ set_embedding_state,
290
+ ):
291
+ if sentence_splitter_config.value is not None and chunks_to_process is not None:
292
+ with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
293
+ output_embeddings = embedding.embed_documents(chunks_to_process)
294
+ _spinner.update("Almost Done")
295
+ time.sleep(1.5)
296
+ set_embedding_state(output_embeddings)
297
+ _spinner.update("Documents Embedded")
298
+ else:
299
+ output_embeddings = None
300
  return
301
 
302
 
303
+ @app.cell
304
+ def preview_chunks(chunks_dict):
305
+ if chunks_dict is not None:
306
+ stats = create_stats(chunks_dict,
307
+ bordered=True,
308
+ object_names=['text','text'],
309
+ group_by_row=True,
310
+ items_per_row=5,
311
+ gap=1,
312
+ label="Chunk")
313
+ ui_chunk_viewer = mo.accordion(
314
+ {
315
+ "View Chunks": stats,
316
+ }
317
+ )
318
+ else:
319
+ ui_chunk_viewer = None
320
+
321
+ ui_chunk_viewer
322
+ return
323
+
324
+
325
+ @app.cell
326
+ def accordion_query_view(chart_visualization, query_stack):
327
+ ui_accordion_part_2_2 = mo.accordion(
328
  {
329
+ # "Query": query_stack
330
+ # "Query": mo.hstack([query_stack, chart_visualization], justify="space-around", align="center", widths=[0.3,0.65])
331
+ "Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3)
332
+ # "Query": mo.vstack([query_stack, chart_visualization], justify="start", align="center")
 
 
333
  }
334
  )
335
+ ui_accordion_part_2_2
336
  return
337
 
338
 
339
+ @app.cell
340
+ def chunker_setup(sentence_splitter_config):
341
+ chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55])
342
+ return (chunker_setup,)
343
+
344
+
345
+ @app.cell
346
+ def file_and_model_select(
347
+ file_loader,
348
+ get_embedding_model_list,
349
+ run_upload_button,
350
+ ):
351
+ select_stack = mo.hstack([get_embedding_model_list(), mo.vstack([file_loader, run_upload_button], align="center")], justify="space-around", align="center", widths=[0.3,0.3])
352
+ return (select_stack,)
353
+
354
+
355
+ @app.cell
356
+ def client_instantiation_form():
357
+ # Endpoints
358
+ wx_platform_url = "https://api.dataplatform.cloud.ibm.com"
359
+ regions = {
360
+ "US": "https://us-south.ml.cloud.ibm.com",
361
+ "EU": "https://eu-de.ml.cloud.ibm.com",
362
+ "GB": "https://eu-gb.ml.cloud.ibm.com",
363
+ "JP": "https://jp-tok.ml.cloud.ibm.com",
364
+ "AU": "https://au-syd.ml.cloud.ibm.com",
365
+ "CA": "https://ca-tor.ml.cloud.ibm.com"
366
+ }
367
+
368
+ # Create a form with multiple elements
369
+ client_instantiation_form = (
370
+ mo.md('''
371
+ ###**watsonx.ai credentials:**
372
+
373
+ {wx_region}
374
+
375
+ {wx_api_key}
376
+
377
+ {project_id}
378
+
379
+ {space_id}
380
+ ''')
381
+ .batch(
382
+ wx_region = mo.ui.dropdown(regions, label="Select your watsonx.ai region:", value="US", searchable=True),
383
+ wx_api_key = mo.ui.text(placeholder="Add your IBM Cloud api-key...", label="IBM Cloud Api-key:",
384
+ kind="password", value=get_cred_value('api_key', creds_var_name='baked_in_creds')),
385
+ project_id = mo.ui.text(placeholder="Add your watsonx.ai project_id...", label="Project_ID:",
386
+ kind="text", value=get_cred_value('project_id', creds_var_name='baked_in_creds')),
387
+ space_id = mo.ui.text(placeholder="Add your watsonx.ai space_id...", label="Space_ID:",
388
+ kind="text", value=get_cred_value('space_id', creds_var_name='baked_in_creds'))
389
+ ,)
390
+ .form(show_clear_button=True, bordered=False)
391
  )
392
+ return (client_instantiation_form,)
393
+
394
+
395
+ @app.cell
396
+ def instantiation_status(
397
+ client_callout_kind,
398
+ client_instantiation_form,
399
+ client_status,
400
+ ):
401
+ client_callout = mo.callout(client_status, kind=client_callout_kind)
402
+ client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10)
403
+ return (client_stack,)
404
+
405
+
406
+ @app.cell
407
+ def client_selector(deployment_client, project_client):
408
+ client_selector = mo.ui.dropdown({"Project Client":project_client,"Deployment Client":deployment_client}, value="Project Client", label="**Select your active client:**")
409
+ return (client_selector,)
410
+
411
+
412
+ @app.cell
413
+ def active_client(client_selector):
414
+ client = client_selector.value
415
+ return (client,)
416
+
417
+
418
+ @app.cell
419
+ def emb_model_selection(client, set_embedding_model_list):
420
+ if client:
421
+ model_specs = client.foundation_models.get_embeddings_model_specs()
422
+ # model_specs = client.foundation_models.get_model_specs()
423
+ resources = model_specs["resources"]
424
+ # Define embedding models reference data
425
+ embedding_models = {
426
+ "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
427
+ "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
428
+ "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
429
+ "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
430
+ "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
431
+ "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
432
+ "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
433
+ "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
434
+ "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
435
+ }
436
+
437
+ # Get model IDs from resources
438
+ model_id_list = []
439
+ for resource in resources:
440
+ model_id_list.append(resource["model_id"])
441
+
442
+ # Create enhanced model data for the table
443
+ embedding_model_data = []
444
+ for model_id in model_id_list:
445
+ model_entry = {"model_id": model_id}
446
+
447
+ # Add properties if model exists in our reference, otherwise use 0
448
+ if model_id in embedding_models:
449
+ model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"]
450
+ model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"]
451
+ else:
452
+ model_entry["max_tokens"] = 0
453
+ model_entry["embedding_dimensions"] = 0
454
+
455
+ embedding_model_data.append(model_entry)
456
+
457
+ embedding_model_selection = mo.ui.table(
458
+ embedding_model_data,
459
+ selection="single", # Only allow selecting one row
460
+ label="Select an embedding model to use.",
461
+ page_size=30,
462
+ initial_selection=[1]
463
+ )
464
+ set_embedding_model_list(embedding_model_selection)
465
+ else:
466
+ default_model_data = [{
467
+ "model_id": "ibm/granite-embedding-107m-multilingual",
468
+ "max_tokens": 512,
469
+ "embedding_dimensions": 384
470
+ }]
471
+
472
+ set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
473
  return
474
 
475
 
476
+ @app.function
477
+ def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
478
+ embedding_model_selection = mo.ui.table(
479
+ model_data,
480
+ selection=selection_type, # Only allow selecting one row
481
+ label=label,
482
+ page_size=30,
483
+ initial_selection=[initial_selection]
484
+ )
485
+ return embedding_model_selection
486
+
487
+
488
+ @app.cell
489
+ def embedding_model():
490
+ get_embedding_model_list, set_embedding_model_list = mo.state(None)
491
+ return get_embedding_model_list, set_embedding_model_list
492
+
493
+
494
+ @app.cell
495
+ def emb_model_parameters(emb_model_max_tk):
496
+ from ibm_watsonx_ai.foundation_models import Embeddings
497
+ from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
498
+
499
+ embed_params = {
500
+ EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk,
501
+ EmbedParams.RETURN_OPTIONS: {
502
+ 'input_text': True
503
+ }
504
+ }
505
+ return Embeddings, embed_params
506
+
507
+
508
+ @app.cell
509
+ def emb_model_state(get_embedding_model_list):
510
+ embedding_model = get_embedding_model_list()
511
+ return (embedding_model,)
512
+
513
+
514
+ @app.cell
515
+ def emb_model_setup(embedding_model):
516
+ emb_model = embedding_model.value[0]['model_id']
517
+ emb_model_max_tk = embedding_model.value[0]['max_tokens']
518
+ emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions']
519
+ return emb_model, emb_model_emb_dim, emb_model_max_tk
520
+
521
+
522
+ @app.cell
523
+ def emb_model_instantiation(Embeddings, client, emb_model, embed_params):
524
+ if client is not None:
525
+ embedding = Embeddings(
526
+ model_id=emb_model,
527
+ api_client=client,
528
+ params=embed_params,
529
+ batch_size=1000,
530
+ concurrency_limit=10
531
+ )
532
+ else:
533
+ embedding = None
534
+ return (embedding,)
535
+
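A minimal sketch of exercising the client configured above; `embed_documents` is the same method the embedding cell later in this file calls, and the dimension check is illustrative:

```python
vectors = embedding.embed_documents(["first chunk", "second chunk"])  # one vector per input text
assert len(vectors) == 2
# len(vectors[0]) should equal emb_model_emb_dim, e.g. 384 for ibm/slate-30m-english-rtrvr
```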
536
+
537
+ @app.cell
538
+ def _():
539
+ get_embedding_state, set_embedding_state = mo.state(None)
540
+ return get_embedding_state, set_embedding_state
541
+
542
+
543
+ @app.cell
544
+ def _():
545
+ get_query_state, set_query_state = mo.state(None)
546
+ return get_query_state, set_query_state
547
 
 
 
 
 
548
 
549
+ @app.cell
550
+ def file_loader_input():
551
+ file_loader = mo.ui.file(
552
+ kind="area",
553
+ filetypes=[".pdf"],
554
+ label=" Load .pdf files ",
555
+ multiple=True
556
+ )
557
+ return (file_loader,)
558
+
559
+
560
+ @app.cell
561
+ def file_loader_run(file_loader):
562
+ if file_loader.value is not None:
563
+ run_upload_button = mo.ui.run_button(label="Load Files")
564
+ else:
565
+ run_upload_button = mo.ui.run_button(disabled=True, label="Load Files")
566
+ return (run_upload_button,)
567
+
568
+
569
+ @app.cell
570
+ def helper_function_tempfiles():
571
+ def create_temp_files_from_uploads(upload_results) -> List[str]:
572
  """
573
+ Creates temporary files from a tuple of FileUploadResults objects and returns their paths.
574
+ Args:
575
+ upload_results: Object containing a value attribute that is a tuple of FileUploadResults
576
+ Returns:
577
+ List of temporary file paths
578
+ """
579
+ temp_file_paths = []
580
+
581
+ # Get the number of items in the tuple
582
+ num_items = len(upload_results)
583
+
584
+ # Process each item by index
585
+ for i in range(num_items):
586
+ result = upload_results[i] # Get item by index
587
+
588
+ # Create a temporary file with the original filename
589
+ temp_dir = tempfile.gettempdir()
590
+ file_name = result.name
591
+ temp_path = os.path.join(temp_dir, file_name)
592
+ # Write the contents to the temp file
593
+ with open(temp_path, 'wb') as temp_file:
594
+ temp_file.write(result.contents)
595
+ # Add the path to our list
596
+ temp_file_paths.append(temp_path)
597
+
598
+ return temp_file_paths
599
+
600
+ def cleanup_temp_files(temp_file_paths: List[str]) -> None:
601
+ """Delete temporary files after use."""
602
+ for path in temp_file_paths:
603
+ if os.path.exists(path):
604
+ os.unlink(path)
605
+ return (create_temp_files_from_uploads,)
606
+
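A usage sketch for the pair of helpers above (note the cell returns only `create_temp_files_from_uploads`, so `cleanup_temp_files` stays local to this cell); the upload objects are assumed to carry `.name` and `.contents` as used above:

```python
paths = create_temp_files_from_uploads(file_loader.value)
try:
    ...  # e.g. hand `paths` to pdf_reader.load_data(...)
finally:
    cleanup_temp_files(paths)  # delete the temporary copies when done
```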
607
+
608
+ @app.function
609
+ def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
610
+ """
611
+ Loads PDF data for each file path and organizes results by original filename.
612
+ Args:
613
+ pdf_reader: The PyMuPDFReader instance
614
+ filepaths: List of temporary file paths
615
+ file_loader_value: The original upload results value containing file information
616
+ show_progress: Whether to show a progress bar during loading (default: True)
617
+ Returns:
618
+ Dictionary mapping original filenames to their loaded text content
619
+ """
620
+ results = {}
621
+
622
+ # Process files with or without progress bar
623
+ if show_progress:
624
+ import marimo as mo
625
+ # Use progress bar with the length of filepaths as total
626
+ with mo.status.progress_bar(
627
+ total=len(filepaths),
628
+ title="Loading PDFs",
629
+ subtitle="Processing documents...",
630
+ completion_title="PDF Loading Complete",
631
+ completion_subtitle=f"{len(filepaths)} documents processed",
632
+ remove_on_exit=True
633
+ ) as bar:
634
+ # Process each file path
635
+ for i, file_path in enumerate(filepaths):
636
+
637
+ original_file_name = file_loader_value[i].name
638
+ bar.update(subtitle=f"Processing {original_file_name}...")
639
+ loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
640
+
641
+ # Store the result with the original filename as the key
642
+ results[original_file_name] = loaded_text
643
+ # Update progress bar
644
+ bar.update(increment=1)
645
+ else:
646
+ # Original logic without progress bar
647
+ for i, file_path in enumerate(filepaths):
648
+ original_file_name = file_loader_value[i].name
649
+ loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
650
+ results[original_file_name] = loaded_text
651
+
652
+ return results
653
+
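The loader's return shape, for orientation (filenames illustrative):

```python
loaded = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value)
# loaded maps each original filename to its list of llama_index Documents, e.g.
# {"report.pdf": [<Document page 1>, <Document page 2>], "slides.pdf": [...]}
```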
654
+
655
+ @app.cell
656
+ def file_readers():
657
+ from llama_index.readers.file import PyMuPDFReader
658
+ from llama_index.readers.file import FlatReader
659
+ from llama_index.core.node_parser import SentenceSplitter
660
+
661
+ ### File Readers
662
+ pdf_reader = PyMuPDFReader()
663
+ # flat_file_reader = FlatReader()
664
+ return SentenceSplitter, pdf_reader
665
+
666
+
667
+ @app.cell
668
+ def sentence_splitter_setup():
669
+ ### Chunker Setup
670
+ sentence_splitter_config = (
671
+ mo.md('''
672
+ ###**Chunking Setup:**
673
+
674
+ > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings may lead to unexpected behavior.
675
+
676
+ Separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.
677
+
678
+ {chunk_size} {chunk_overlap}
679
+
680
+ {separator} {paragraph_separator}
681
+
682
+ {secondary_chunking_regex}
683
+
684
+ {include_metadata}
685
+
686
+ ''')
687
+ .batch(
688
+ chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=350, show_value=True),
689
+ chunk_overlap = mo.ui.slider(start=1, stop=1000, step=1, label="**Chunk Overlap:**", value=50, show_value=True),
690
+ separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
691
+ paragraph_separator = mo.ui.text(placeholder="Define a paragraph separator",
692
+ label="**Paragraph Separator:**", kind="text",
693
+ value="\n\n\n"),
694
+ secondary_chunking_regex = mo.ui.text(placeholder="Define a secondary chunking regex",
695
+ label="**Chunking Regex:**", kind="text",
696
+ value="[^,.;?!]+[,.;?!]?"),
697
+ include_metadata= mo.ui.checkbox(value=True, label="**Include Metadata**")
698
+ )
699
+ .form(show_clear_button=True, bordered=False)
700
  )
701
+ return (sentence_splitter_config,)
702
 
703
 
704
  @app.cell
705
+ def sentence_splitter_instantiation(
706
+ SentenceSplitter,
707
+ sentence_splitter_config,
708
+ ):
709
+ ### Chunker/Sentence Splitter
710
+ if sentence_splitter_config.value is not None:
711
+ sentence_splitter_config_values = sentence_splitter_config.value
712
+ validated_chunk_overlap = min(sentence_splitter_config_values.get("chunk_overlap"),
713
+ int(sentence_splitter_config_values.get("chunk_size") * 0.3))
714
+
715
+ sentence_splitter = SentenceSplitter(
716
+ chunk_size=sentence_splitter_config_values.get("chunk_size"),
717
+ chunk_overlap=validated_chunk_overlap,
718
+ separator=sentence_splitter_config_values.get("separator"),
719
+ paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"),
720
+ secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"),
721
+ include_metadata=sentence_splitter_config_values.get("include_metadata"),
722
+ )
723
+
724
+ else:
725
+ sentence_splitter = SentenceSplitter(
726
+ chunk_size=2048,
727
+ chunk_overlap=204,
728
+ separator=" ",
729
+ paragraph_separator="\n\n\n",
730
+ secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
731
+ include_metadata=True,
732
+ )
733
+ return (sentence_splitter,)
734
+
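The `min(...)` guard above caps the overlap at 30% of the chunk size; two worked cases with the form's default overlap of 50:

```python
# chunk_size=350, chunk_overlap=50 -> cap = int(350 * 0.3) = 105 -> min(50, 105) = 50 (kept)
# chunk_size=100, chunk_overlap=50 -> cap = int(100 * 0.3) = 30  -> min(50, 30)  = 30 (clamped)
```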
735
+
736
+ @app.cell
737
+ def text_state():
738
+ get_text_state, set_text_state = mo.state(None)
739
+ return get_text_state, set_text_state
740
 
741
 
742
  @app.cell
743
+ def chunk_state():
744
+ get_chunk_state, set_chunk_state = mo.state(None)
745
+ return get_chunk_state, set_chunk_state
746
+
747
+
748
+ @app.function
749
+ def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
750
+ """
751
+ Process each document in the loaded_texts dictionary using the sentence_splitter,
752
+ with an optional marimo progress bar tracking progress at document level.
753
+
754
+ Args:
755
+ loaded_texts (dict): Dictionary containing lists of Document objects
756
+ sentence_splitter: The sentence splitter object with get_nodes_from_documents method
757
+ show_progress (bool): Whether to show a progress bar during processing
758
+
759
+ Returns:
760
+ dict: Dictionary with the same structure but containing chunked texts
761
+ """
762
+ chunked_texts_dict = {}
763
+
764
+ # Get the total number of documents across all keys
765
+ total_docs = sum(len(docs) for docs in loaded_texts.values())
766
+ processed_docs = 0
767
+
768
+ # Process with or without progress bar
769
+ if show_progress:
770
+ import marimo as mo
771
+ # Use progress bar with the total number of documents as total
772
+ with mo.status.progress_bar(
773
+ total=total_docs,
774
+ title="Processing Documents",
775
+ subtitle="Chunking documents...",
776
+ completion_title="Processing Complete",
777
+ completion_subtitle=f"{total_docs} documents processed",
778
+ remove_on_exit=True
779
+ ) as bar:
780
+ # Process each key-value pair in the loaded_texts dictionary
781
+ for key, documents in loaded_texts.items():
782
+ # Update progress bar subtitle to show current key
783
+ doc_count = len(documents)
784
+ bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")
785
+
786
+ # Apply the sentence splitter to each list of documents
787
+ chunked_texts = sentence_splitter.get_nodes_from_documents(
788
+ documents,
789
+ show_progress=False # Disable internal progress to avoid nested bars
790
+ )
791
+
792
+ # Store the result with the same key
793
+ chunked_texts_dict[key] = chunked_texts
794
+ time.sleep(0.15)
795
+
796
+ # Update progress bar with the number of documents in this batch
797
+ bar.update(increment=doc_count)
798
+ processed_docs += doc_count
799
+ else:
800
+ # Process without progress bar
801
+ for key, documents in loaded_texts.items():
802
+ chunked_texts = sentence_splitter.get_nodes_from_documents(
803
+ documents,
804
+ show_progress=True # Use the internal progress bar if no marimo bar
805
+ )
806
+ chunked_texts_dict[key] = chunked_texts
807
+
808
+ return chunked_texts_dict
809
 
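A one-line usage sketch, matching how the `chunk_documents_to_nodes` cell earlier in the file calls this function:

```python
chunked = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
# chunked["report.pdf"] -> list of chunk nodes produced by the splitter for that file
```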
810
 
811
  @app.cell
812
+ def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
813
+ if chunked_texts is not None and sentence_splitter:
814
+ chunked_documents = get_chunk_state()
815
+ else:
816
+ chunked_documents = None
817
+ return (chunked_documents,)
818
 
819
 
820
  @app.cell
821
+ def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
822
+ if chunked_documents is not None:
823
+ dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
824
+ nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)
825
+ else:
826
+ dict_from_nodes = None
827
+ nodes_from_dict = None
828
+ return (dict_from_nodes,)
829
 
830
 
831
  @app.cell
832
+ def chunks_to_process(
833
+ dict_from_nodes,
834
+ document_range_stack,
835
+ get_data_in_range_triplequote,
836
+ ):
837
+ if dict_from_nodes is not None and document_range_stack.value is not None:
838
+
839
+ chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)
840
+
841
+ if document_range_stack.value is not None:
842
+ chunk_start_idx = document_range_stack.value[0]
843
+ chunk_end_idx = document_range_stack.value[1]
844
+ else:
845
+ chunk_start_idx = 0
846
+ chunk_end_idx = len(chunk_dict_df)
847
+
848
+ chunk_range_index = [chunk_start_idx, chunk_end_idx]
849
+ chunks_dict = get_data_in_range_triplequote(chunk_dict_df,
850
+ index_range=chunk_range_index,
851
+ columns_to_include=["text"])
852
+
853
+ chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
854
+ else:
855
+ chunk_objects = None
856
+ chunks_dict = None
857
+ chunks_to_process = None
858
+ return chunks_dict, chunks_to_process
859
 
860
 
861
+ @app.cell
862
+ def helper_function_doc_formatting():
863
+ def llamaindex_convert_docs_multi(items):
864
  """
865
+ Automatically convert between document objects and dictionaries.
866
 
867
+ This function handles:
868
+ - Converting dictionaries to document objects
869
+ - Converting document objects to dictionaries
870
+ - Processing lists or individual items
871
+ - Supporting dictionary structures where values are lists of documents
872
 
873
+ Args:
874
+ items: A document object, dictionary, or list of either.
875
+ Can also be a dictionary mapping filenames to lists of documents.
876
 
877
+ Returns:
878
+ Converted item(s) maintaining the original structure
879
  """
880
+ # Handle empty or None input
881
+ if not items:
882
+ return []
883
+
884
+ # Handle dictionary mapping filenames to document lists (from load_pdf_data)
885
+ if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
886
+ result = {}
887
+ for filename, doc_list in items.items():
888
+ result[filename] = llamaindex_convert_docs(doc_list)
889
+ return result
890
+
891
+ # Handle single items (not in a list)
892
+ if not isinstance(items, list):
893
+ # Single dictionary to document
894
+ if isinstance(items, dict):
895
+ # Determine document class
896
+ doc_class = None
897
+ if 'doc_type' in items:
898
+ import importlib
899
+ module_path, class_name = items['doc_type'].rsplit('.', 1)
900
+ module = importlib.import_module(module_path)
901
+ doc_class = getattr(module, class_name)
902
+ if not doc_class:
903
+ from llama_index.core.schema import Document
904
+ doc_class = Document
905
+ return doc_class.from_dict(items)
906
+ # Single document to dictionary
907
+ elif hasattr(items, 'to_dict'):
908
+ return items.to_dict()
909
+ # Return as is if can't convert
910
+ return items
911
+
912
+ # Handle list input
913
+ result = []
914
+
915
+ # Handle empty list
916
+ if len(items) == 0:
917
+ return result
918
+
919
+ # Determine the type of conversion based on the first non-None item
920
+ first_item = next((item for item in items if item is not None), None)
921
+
922
+ # If we found no non-None items, return empty list
923
+ if first_item is None:
924
+ return result
925
+
926
+ # Convert dictionaries to documents
927
+ if isinstance(first_item, dict):
928
+ # Get the right document class from the items themselves
929
+ doc_class = None
930
+ # Try to get doc class from metadata if available
931
+ if 'doc_type' in first_item:
932
+ import importlib
933
+ module_path, class_name = first_item['doc_type'].rsplit('.', 1)
934
+ module = importlib.import_module(module_path)
935
+ doc_class = getattr(module, class_name)
936
+ if not doc_class:
937
+ # Fallback to default Document class from llama_index
938
+ from llama_index.core.schema import Document
939
+ doc_class = Document
940
+
941
+ # Convert each dictionary to document
942
+ for item in items:
943
+ if isinstance(item, dict):
944
+ result.append(doc_class.from_dict(item))
945
+ elif item is None:
946
+ result.append(None)
947
+ elif isinstance(item, list):
948
+ result.append(llamaindex_convert_docs(item))
949
+ else:
950
+ result.append(item)
951
+
952
+ # Convert documents to dictionaries
953
+ else:
954
+ for item in items:
955
+ if hasattr(item, 'to_dict'):
956
+ result.append(item.to_dict())
957
+ elif item is None:
958
+ result.append(None)
959
+ elif isinstance(item, list):
960
+ result.append(llamaindex_convert_docs(item))
961
+ else:
962
+ result.append(item)
963
+
964
+ return result
965
+
966
+ def llamaindex_convert_docs(items):
967
+ """
968
+ Automatically convert between document objects and dictionaries.
969
 
970
+ Args:
971
+ items: A list of document objects or dictionaries
972
 
973
+ Returns:
974
+ List of converted items (dictionaries or document objects)
 
975
  """
976
+ result = []
977
+
978
+ # Handle empty or None input
979
+ if not items:
980
+ return result
981
+
982
+ # Determine the type of conversion based on the first item
983
+ if isinstance(items[0], dict):
984
+ # Get the right document class from the items themselves
985
+ # Look for a 'doc_type' or '__class__' field in the dictionary
986
+ doc_class = None
987
+
988
+ # Try to get doc class from metadata if available
989
+ if 'doc_type' in items[0]:
990
+ import importlib
991
+ module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
992
+ module = importlib.import_module(module_path)
993
+ doc_class = getattr(module, class_name)
994
+
995
+ if not doc_class:
996
+ # Fallback to default Document class from llama_index
997
+ from llama_index.core.schema import Document
998
+ doc_class = Document
999
+
1000
+ # Convert dictionaries to documents
1001
+ for item in items:
1002
+ if isinstance(item, dict):
1003
+ result.append(doc_class.from_dict(item))
1004
+ else:
1005
+ # Convert documents to dictionaries
1006
+ for item in items:
1007
+ if hasattr(item, 'to_dict'):
1008
+ result.append(item.to_dict())
1009
+
1010
+ return result
1011
+ return (llamaindex_convert_docs_multi,)
1012
 
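The converter is direction-agnostic; the round trip below mirrors the `prep_cumulative_df` cell earlier in the file:

```python
dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)   # node objects -> dicts
nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)     # dicts -> node objects again
```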
 
1014
+ @app.cell
1015
+ def helper_function_create_df():
1016
+ def create_document_dataframes(dict_from_docs):
1017
  """
1018
+ Creates a pandas DataFrame for each file in the dictionary.
1019
+
1020
+ Args:
1021
+ dict_from_docs: Dictionary mapping filenames to lists of documents
1022
+
1023
+ Returns:
1024
+ List of pandas DataFrames, each representing all documents from a single file
1025
+ """
1026
+ dataframes = []
1027
+
1028
+ for filename, docs in dict_from_docs.items():
1029
+ # Create a list to hold all document records for this file
1030
+ file_records = []
1031
+
1032
+ for i, doc in enumerate(docs):
1033
+ # Convert the document to a format compatible with DataFrame
1034
+ if hasattr(doc, 'to_dict'):
1035
+ doc_data = doc.to_dict()
1036
+ elif isinstance(doc, dict):
1037
+ doc_data = doc
1038
+ else:
1039
+ doc_data = {'content': str(doc)}
1040
+
1041
+ # Add document index information
1042
+ doc_data['doc_index'] = i
1043
+
1044
+ # Add to the list of records for this file
1045
+ file_records.append(doc_data)
1046
+
1047
+ # Create a single DataFrame for all documents in this file
1048
+ if file_records:
1049
+ df = pd.DataFrame(file_records)
1050
+ df['filename'] = filename # Add filename as a column
1051
+ dataframes.append(df)
1052
+
1053
+ return dataframes
1054
+
1055
+ def create_dataframe_previews(dataframe_list, page_size=5):
1056
+ """
1057
+ Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.
1058
+
1059
+ Args:
1060
+ dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
1061
+ page_size: Number of rows to show per page for each component
1062
+
1063
+ Returns:
1064
+ List of mo.ui.dataframe components
1065
+ """
1066
+ # Create a list of mo.ui.dataframe components
1067
+ preview_components = []
1068
+
1069
+ for df in dataframe_list:
1070
+ # Create a mo.ui.dataframe component for this DataFrame
1071
+ preview = mo.ui.dataframe(df, page_size=page_size)
1072
+ preview_components.append(preview)
1073
+
1074
+ return preview_components
1075
  return
1076
 
1077
 
1078
+ @app.cell
1079
+ def helper_function_chart_preparation():
1080
+ import altair as alt
1081
+ import numpy as np
1082
+ import plotly.express as px
1083
+ from sklearn.manifold import TSNE
1084
+
1085
+ def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
1086
  """
1087
+ Prepare embedding data for visualization
1088
 
1089
+ Args:
1090
+ embeddings: List of embeddings arrays
1091
+ texts: List of text strings
1092
+ model_id: Embedding model ID (optional)
1093
+ embedding_dimensions: Embedding dimensions (optional)
1094
 
1095
+ Returns:
1096
+ DataFrame with processed data and metadata
1097
+ """
1098
+ # Flatten embeddings (in case they're nested)
1099
+ flattened_embeddings = []
1100
+ for emb in embeddings:
1101
+ if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
1102
+ flattened_embeddings.append(emb[0]) # Take first element if nested
1103
+ else:
1104
+ flattened_embeddings.append(emb)
1105
+
1106
+ # Convert to numpy array
1107
+ embedding_array = np.array(flattened_embeddings)
1108
+
1109
+ # Apply dimensionality reduction (t-SNE)
1110
+ tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1))
1111
+ reduced_embeddings = tsne.fit_transform(embedding_array)
1112
+
1113
+ # Create truncated texts for display
1114
+ truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]
1115
+
1116
+ # Create dataframe for visualization
1117
+ df = pd.DataFrame({
1118
+ "x": reduced_embeddings[:, 0],
1119
+ "y": reduced_embeddings[:, 1],
1120
+ "text": truncated_texts,
1121
+ "full_text": texts,
1122
+ "index": range(len(texts))
1123
+ })
1124
+
1125
+ # Add metadata
1126
+ metadata = {
1127
+ "model_id": model_id,
1128
+ "embedding_dimensions": embedding_dimensions
1129
+ }
1130
 
1131
+ return df, metadata
 
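One detail worth noting above: scikit-learn requires t-SNE's perplexity to be smaller than the number of samples, hence `perplexity=min(30, len(embedding_array)-1)`; for example:

```python
# 12 chunks  -> perplexity = min(30, 12 - 1)  = 11
# 100 chunks -> perplexity = min(30, 100 - 1) = 30
```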
1132
 
1133
+ def create_embedding_chart(df, metadata=None):
1134
+ """
1135
+ Create an Altair chart for embedding visualization
1136
 
1137
+ Args:
1138
+ df: DataFrame with x, y coordinates and text
1139
+ metadata: Dictionary with model_id and embedding_dimensions
1140
 
1141
+ Returns:
1142
+ Altair chart
1143
+ """
1144
+ model_id = metadata.get("model_id") if metadata else None
1145
+ embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
1146
 
1147
+ selection = alt.selection_multi(fields=['index'])
 
1148
 
1149
+ base = alt.Chart(df).encode(
1150
+ x=alt.X("x:Q", title="Dimension 1"),
1151
+ y=alt.Y("y:Q", title="Dimension 2"),
1152
+ tooltip=["text", "index"]
1153
+ )
1154
 
1155
+ points = base.mark_circle(size=100).encode(
1156
+ color=alt.Color("index:N", legend=None),
1157
+ opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
1158
+ ).add_selection(selection)  # bind the click selection so point opacity reacts to it
1159
 
1160
+ text = base.mark_text(align="left", dx=7).encode(
1161
+ text="index:N"
1162
+ )
1163
 
1164
+ return (points + text).properties(
1165
+ width=700,
1166
+ height=500,
1167
+ title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
1168
+ ).interactive()
1169
 
1170
+ def show_selected_text(indices, texts):
1171
+ """
1172
+ Create markdown display for selected texts
1173
 
1174
+ Args:
1175
+ indices: List of selected indices
1176
+ texts: List of all texts
1177
 
1178
+ Returns:
1179
+ Markdown string
1180
  """
1181
+ if not indices:
1182
+ return "No text selected"
1183
 
1184
+ selected_texts = [texts[i] for i in indices if i < len(texts)]
1185
+ return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])
+
+    def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
+        """
+        Prepare embedding data for 3D visualization
+
+        Args:
+            embeddings: List of embeddings arrays
+            texts: List of text strings
+            model_id: Embedding model ID (optional)
+            embedding_dimensions: Embedding dimensions (optional)
+
+        Returns:
+            DataFrame with processed data and metadata
+        """
+        # Flatten embeddings (in case they're nested)
+        flattened_embeddings = []
+        for emb in embeddings:
+            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
+                flattened_embeddings.append(emb[0])
+            else:
+                flattened_embeddings.append(emb)
+
+        # Convert to numpy array
+        embedding_array = np.array(flattened_embeddings)
+
+        # Handle the case of a single embedding differently
+        if len(embedding_array) == 1:
+            # For a single point, we don't need t-SNE; just use a fixed position
+            reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
+        else:
+            # Apply dimensionality reduction to 3D.
+            # t-SNE requires perplexity < n_samples, so clamp it to a valid range.
+            perplexity_value = max(1.0, min(30, len(embedding_array) - 1))
+            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
+            reduced_embeddings = tsne.fit_transform(embedding_array)
+
+        # Format texts for display
+        formatted_texts = []
+        for text in texts:
+            # Truncate if needed
+            if len(text) > 500:
+                text = text[:500] + "..."
+
+            # Insert line breaks for wrapping
+            wrapped_text = ""
+            for i in range(0, len(text), 50):
+                wrapped_text += text[i:i+50] + "<br>"
+
+            formatted_texts.append("<b>" + wrapped_text + "</b>")
+
+        # Create dataframe for visualization
+        df = pd.DataFrame({
+            "x": reduced_embeddings[:, 0],
+            "y": reduced_embeddings[:, 1],
+            "z": reduced_embeddings[:, 2],
+            "text": formatted_texts,
+            "full_text": texts,
+            "index": range(len(texts)),
+            "embedding": flattened_embeddings  # Store the original embeddings for later use
+        })
+
+        # Add metadata
+        metadata = {
+            "model_id": model_id,
+            "embedding_dimensions": embedding_dimensions
+        }
+
+        return df, metadata
+
+    def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int = 3):
+        """
+        Create a 3D Plotly chart for embedding visualization with proximity-based coloring
+        """
+        model_id = metadata.get("model_id") if metadata else None
+        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
+
+        # Calculate the proximity between points
+        from scipy.spatial.distance import pdist, squareform
+
+        # Get the coordinates as a numpy array
+        coords = df[['x', 'y', 'z']].values
+
+        # Calculate pairwise distances
+        dist_matrix = squareform(pdist(coords))
+
+        # For each point, find its average distance to all other points
+        avg_distances = np.mean(dist_matrix, axis=1)
+
+        # Add this to the dataframe - smaller values = closer to other points
+        df['proximity'] = avg_distances
+
+        # Create 3D scatter plot with proximity-based coloring
+        fig = px.scatter_3d(
+            df,
+            x='x',
+            y='y',
+            z='z',
+            color='proximity',  # Color based on proximity
+            color_continuous_scale='Viridis_r',  # Reversed so closer points are warmer colors
+            hover_data=['text', 'index', 'proximity'],
+            labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
+            title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
+            text='index',
+        )
+
+        # Update marker size and styling
+        fig.update_traces(
+            marker=dict(
+                size=marker_size_var,  # Small markers keep dense clouds readable
+                opacity=0.7,  # Slightly transparent
+                symbol="diamond",  # Marker shape (other options: "circle", "square", "cross", "x")
+                line=dict(
+                    width=0.5,  # Very thin border
+                    color="white"  # White outline makes small dots more visible
+                )
+            ),
+            textfont=dict(
+                color="rgba(255, 255, 255, 0.3)",
+                size=8
+            ),
+            hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
+            hoverinfo="text+name",
+            hoverlabel=dict(
+                bgcolor="white",  # White background for hover labels
+                font_size=12  # Font size for hover text
+            ),
+            selector=dict(type='scatter3d')
+        )
+
+        # Configure axes, camera, canvas size, and colorbar
+        fig.update_layout(
+            scene=dict(
+                xaxis=dict(
+                    title='Dimension 1',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                yaxis=dict(
+                    title='Dimension 2',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                zaxis=dict(
+                    title='Dimension 3',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                # Control camera view angle
+                camera=dict(
+                    up=dict(x=0, y=0, z=1),
+                    center=dict(x=0, y=0, z=0),
+                    eye=dict(x=1.25, y=1.25, z=1.25),
+                ),
+                aspectratio=dict(x=1, y=1, z=1),
+                aspectmode='data'
+            ),
+            width=int(chart_width),
+            height=int(chart_height),
+            margin=dict(r=20, l=10, b=10, t=50),
+            paper_bgcolor="rgb(0, 0, 0)",
+            plot_bgcolor="rgb(0, 0, 0)",
+            coloraxis_colorbar=dict(
+                title="Average Distance",
+                thicknessmode="pixels", thickness=20,
+                lenmode="pixels", len=400,
+                yanchor="top", y=1,
+                ticks="outside",
+                dtick=0.1
+            )
+        )
+
+        return fig
+    return create_3d_embedding_chart, prepare_embedding_data_3d
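+
+
+# Illustrative sketch for the 3D flow (placeholder data; `sample_embeddings` and
+# `sample_texts` would come from an earlier embedding cell):
+#
+#     df_3d, meta_3d = prepare_embedding_data_3d(
+#         sample_embeddings, sample_texts,
+#         model_id="my-embedding-model", embedding_dimensions=384,
+#     )
+#     mo.ui.plotly(create_3d_embedding_chart(df_3d, meta_3d, 1000, 700, marker_size_var=5))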
+
+
+@app.cell
+def helper_function_text_preparation():
+    def convert_table_to_json_docs(df, selected_columns=None):
+        """
+        Convert a pandas DataFrame or dictionary to a list of JSON documents.
+        Dynamically includes columns based on user selection.
+        Column names are standardized to lowercase with underscores instead of spaces
+        and special characters removed.
+
+        Args:
+            df: The DataFrame or dictionary to process
+            selected_columns: List of column names to include in the output documents
+
+        Returns:
+            list: A list of dictionaries, each representing a row as a JSON document
+        """
+        import pandas as pd
+        import re
+
+        def standardize_key(key):
+            """Convert a column name to lowercase with underscores instead of spaces and no special characters"""
+            if not isinstance(key, str):
+                return str(key).lower()
+            # Replace spaces with underscores and convert to lowercase
+            key = key.lower().replace(' ', '_')
+            # Remove special characters (keeping alphanumeric and underscores)
+            return re.sub(r'[^\w]', '', key)
+
+        # Handle case when input is a dictionary
+        if isinstance(df, dict):
+            # Filter the dictionary to include only selected columns
+            if selected_columns:
+                return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
+            else:
+                # If no columns selected, return all key-value pairs with standardized keys
+                return [{standardize_key(k): v for k, v in df.items()}]
+
+        # Handle case when df is None
+        if df is None:
+            return []
+
+        # Ensure df is a DataFrame
+        if not isinstance(df, pd.DataFrame):
+            try:
+                df = pd.DataFrame(df)
+            except Exception:
+                return []  # Return empty list if conversion fails
+
+        # Now check if DataFrame is empty
+        if df.empty:
+            return []
+
+        # If no columns are specifically selected, use all available columns
+        if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
+            selected_columns = list(df.columns)
+
+        # Determine which columns exist in the DataFrame
+        available_columns = []
+        columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
+
+        for col in selected_columns:
+            if col in df.columns:
+                available_columns.append(col)
+            elif isinstance(col, str) and col.lower() in columns_lower:
+                available_columns.append(columns_lower[col.lower()])
+
+        # If no valid columns found, return empty list
+        if not available_columns:
+            return []
+
+        # Process rows
+        json_docs = []
+        for _, row in df.iterrows():
+            doc = {}
+            for col in available_columns:
+                value = row[col]
+                # Standardize the column name when adding to document
+                std_col = standardize_key(col)
+                doc[std_col] = None if pd.isna(value) else value
+            json_docs.append(doc)
+
+        return json_docs
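+    # Illustrative sketch (toy frame, not part of the app):
+    #
+    #     demo_df = pd.DataFrame({"Doc Title": ["A", "B"], "Source URL": ["u1", "u2"]})
+    #     convert_table_to_json_docs(demo_df, selected_columns=["Doc Title"])
+    #     # -> [{'doc_title': 'A'}, {'doc_title': 'B'}]  (keys standardized)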
+
+    def get_column_values(df, columns_to_include):
        """
+        Extract values from specified columns of a dataframe as lists.
+
+        Args:
+            df: A pandas DataFrame
+            columns_to_include: A list of column names to extract
+
+        Returns:
+            Dictionary with column names as keys and their values as lists
        """
+        result = {}
+
+        # Validate that columns exist in the dataframe
+        valid_columns = [col for col in columns_to_include if col in df.columns]
+        invalid_columns = set(columns_to_include) - set(valid_columns)
+
+        if invalid_columns:
+            print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")
+
+        # Extract values for each valid column
+        for col in valid_columns:
+            result[col] = df[col].tolist()
+
+        return result
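+    # Illustrative sketch: extracting columns as plain lists; unknown columns
+    # only trigger a printed warning.
+    #
+    #     get_column_values(demo_df, ["Doc Title", "missing_col"])
+    #     # -> {'Doc Title': ['A', 'B']}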
+
+    def get_data_in_range(doc_dict_df, index_range, columns_to_include):
+        """
+        Extract values from specified columns of a dataframe within a given index range.
+
+        Args:
+            doc_dict_df: The pandas DataFrame to extract data from
+            index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
+            columns_to_include: A list of column names to extract
+
+        Returns:
+            Dictionary with column names as keys and their values (within the index range) as lists
+        """
+        # Validate the index range
+        max_index = len(doc_dict_df)
+        if index_range <= 0:
+            print(f"Warning: Invalid index range {index_range}. Must be positive.")
+            return {}
+
+        # Adjust index_range if it exceeds the dataframe length
+        if index_range > max_index:
+            print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
+            index_range = max_index
+
+        # Slice the dataframe to get rows from 0 to index_range-1
+        df_subset = doc_dict_df.iloc[:index_range]
+
+        # Use the provided get_column_values function to extract column data
+        return get_column_values(df_subset, columns_to_include)
+
+    def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
+        """
+        Extract values from specified columns of a dataframe within a given index range,
+        escaping URL schemes in string values. (The triple-quote wrapping this function
+        is named for is currently disabled.)
+
+        Args:
+            doc_dict_df: The pandas DataFrame to extract data from
+            index_range: A list of two integers specifying the start and end indices of rows to include
+                         (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
+            columns_to_include: A list of column names to extract
+        """
+        # Validate the index range
+        start_idx, end_idx = index_range
+        max_index = len(doc_dict_df)
+
+        # Validate start index
+        if start_idx < 0:
+            print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
+            start_idx = 0
+
+        # Validate end index
+        if end_idx <= start_idx:
+            print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
+            end_idx = start_idx + 1
+
+        # Adjust end index if it exceeds the dataframe length
+        if end_idx > max_index:
+            print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
+            end_idx = max_index
+
+        # Slice the dataframe to get rows from start_idx to end_idx-1
+        df_subset = doc_dict_df.iloc[start_idx:end_idx]
+
+        # Use the provided get_column_values function to extract column data
+        result = get_column_values(df_subset, columns_to_include)
+
+        # Escape URL schemes in string values so they render literally downstream
+        for col in result:
+            if isinstance(result[col], list):
+                processed_items = []
+                for item in result[col]:
+                    if isinstance(item, str):
+                        # Replace http:// and https:// with escaped versions
+                        item = item.replace("http://", "http\\://").replace("https://", "https\\://")
+                        # processed_items.append('"""' + item + '"""')  # triple-quote wrapping disabled
+                        processed_items.append(item)
+                    else:
+                        processed_items.append(item)
+                result[col] = processed_items
+        return result
+    return (get_data_in_range_triplequote,)
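+
+
+# Illustrative sketch (toy frame; shows the URL escaping on a [start, end) slice):
+#
+#     url_df = pd.DataFrame({"text": ["see https://example.com", "plain row"]})
+#     get_data_in_range_triplequote(url_df, [0, 2], ["text"])
+#     # -> {'text': ['see https\\://example.com', 'plain row']}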


 @app.cell
+def prepare_doc_select(sentence_splitter_config):
+    def prepare_document_selection(node_dict):
+        """
+        Creates document selection UI component.
+
+        Args:
+            node_dict: Dictionary mapping filenames to lists of documents
+
+        Returns:
+            mo.ui component for document selection
+        """
+        # Calculate total number of documents across all files
+        total_docs = sum(len(docs) for docs in node_dict.values())
+
+        # Create a combined DataFrame of all documents for table selection
+        all_docs_records = []
+        doc_index_global = 0
+        for filename, docs in node_dict.items():
+            for i, doc in enumerate(docs):
+                # Convert the document to a format compatible with DataFrame
+                if hasattr(doc, 'to_dict'):
+                    doc_data = doc.to_dict()
+                elif isinstance(doc, dict):
+                    doc_data = doc
+                else:
+                    doc_data = {'content': str(doc)}
+
+                # Add metadata
+                doc_data['filename'] = filename
+                doc_data['doc_index'] = i
+                doc_data['global_index'] = doc_index_global
+                all_docs_records.append(doc_data)
+                doc_index_global += 1
+
+        # Create UI component
+        stop_value = max(total_docs, 2)
+        llama_docs = mo.ui.range_slider(
+            start=1,
+            stop=stop_value,
+            step=1,
+            full_width=True,
+            show_value=True,
+            label="**Select a Range of Chunks to Visualize:**"
+        ).form(submit_button_disabled=check_state(sentence_splitter_config.value))
+
+        return llama_docs
+    return (prepare_document_selection,)
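+
+
+# Illustrative sketch (toy node dict; the real one comes from the document loaders,
+# and `sentence_splitter_config` must already hold a submitted form):
+#
+#     demo_nodes = {"a.pdf": [{"content": "chunk one"}, {"content": "chunk two"}]}
+#     prepare_document_selection(demo_nodes)  # renders a 1..2 range-slider form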
+
+
+@app.cell
+def document_range_selection(
+    dict_from_nodes,
+    prepare_document_selection,
+    set_range_slider_state,
+):
+    if dict_from_nodes is not None:
+        llama_docs = prepare_document_selection(dict_from_nodes)
+        set_range_slider_state(llama_docs)
+    else:
+        bare_dict = {}
+        llama_docs = prepare_document_selection(bare_dict)
    return


+@app.function
+def create_cumulative_dataframe(dict_from_docs):
+    """
+    Creates a cumulative DataFrame from a nested dictionary of documents.
+
+    Args:
+        dict_from_docs: Dictionary mapping filenames to lists of documents
+
+    Returns:
+        DataFrame with all documents flattened with global indices
+    """
+    # Create a list to hold all document records
+    all_records = []
+    global_idx = 1  # Start from 1 to match range slider expectations
+
+    for filename, docs in dict_from_docs.items():
+        for i, doc in enumerate(docs):
+            # Convert the document to a dict format
+            if hasattr(doc, 'to_dict'):
+                doc_data = doc.to_dict()
+            elif isinstance(doc, dict):
+                doc_data = doc.copy()
+            else:
+                doc_data = {'content': str(doc)}
+
+            # Add additional metadata
+            doc_data['filename'] = filename
+            doc_data['doc_index'] = i
+            doc_data['global_index'] = global_idx
+
+            # If there's 'content' but no 'text', create a 'text' field
+            if 'content' in doc_data and 'text' not in doc_data:
+                doc_data['text'] = doc_data['content']
+
+            all_records.append(doc_data)
+            global_idx += 1
+
+    # Create DataFrame from all records
+    return pd.DataFrame(all_records)
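+
+
+# Illustrative sketch (toy input; the real dict comes from the loader cells):
+#
+#     demo_nodes = {"a.pdf": [{"content": "c1"}], "b.pdf": [{"content": "c2"}]}
+#     create_cumulative_dataframe(demo_nodes)[["filename", "global_index", "text"]]
+#     # -> two rows, global_index 1..2, 'text' mirrored from 'content'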
+
+
+@app.function
+def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
+    """
+    Create a list of stat objects for each item in the specified dictionary.
+
+    Parameters:
+    - texts_dict (dict): Dictionary containing the text data
+    - bordered (bool): Whether the stats should be bordered
+    - object_names (list or tuple): Two object names to use for label and value
+      [label_object, value_object]
+    - group_by_row (bool): Whether to group stats in rows (horizontal stacks)
+    - items_per_row (int): Number of stat objects per row when group_by_row is True
+    - gap (int): Gap between stat objects within a row
+    - label (str): Noun used in each stat's value line (e.g., "Chunk")
+
+    Returns:
+    - object: A vertical stack of stat objects or rows of stat objects
+    """
+    if not object_names or len(object_names) < 2:
+        raise ValueError("You must provide two object names as a list or tuple")
+
+    label_object = object_names[0]
+    value_object = object_names[1]
+
+    # Validate that both objects exist in the dictionary
+    if label_object not in texts_dict:
+        raise ValueError(f"Label object '{label_object}' not found in texts_dict")
+    if value_object not in texts_dict:
+        raise ValueError(f"Value object '{value_object}' not found in texts_dict")
+
+    # Determine how many items to process (based on the label object length)
+    num_items = len(texts_dict[label_object])
+
+    # Create individual stat objects
+    individual_stats = []
+    for i in range(num_items):
+        stat = mo.stat(
+            label=texts_dict[label_object][i],
+            value=f"{label} Number: {len(texts_dict[value_object][i])}",
+            bordered=bordered
+        )
+        individual_stats.append(stat)
+
+    # If grouping is not enabled, just return a vertical stack of all stats
+    if not group_by_row:
+        return mo.vstack(individual_stats, wrap=False)
+
+    # Group stats into rows based on items_per_row
+    rows = []
+    for i in range(0, num_items, items_per_row):
+        # Get a slice of stats for this row (up to items_per_row items)
+        row_stats = individual_stats[i:i+items_per_row]
+        # Create a horizontal stack for this row
+        widths = [0.35] * len(row_stats)
+        row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths)
+        rows.append(row)
+
+    # Return a vertical stack of all rows
+    return mo.vstack(rows)
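+
+
+# Illustrative sketch: two file names paired with their chunk lists.
+#
+#     stats_input = {"names": ["a.pdf", "b.pdf"], "chunks": [["c1", "c2"], ["c3"]]}
+#     create_stats(stats_input, bordered=True, object_names=["names", "chunks"])
+#     # one mo.stat per file, valued "Chunk Number: <count of chunks>"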
+
+
+@app.cell
+def prepare_chart_embeddings(
+    chunks_to_process,
+    emb_model,
+    emb_model_emb_dim,
+    get_embedding_state,
+    prepare_embedding_data_3d,
+):
+    if chunks_to_process is not None and get_embedding_state() is not None:
+        chart_dataframe, chart_metadata = prepare_embedding_data_3d(
+            get_embedding_state(),
+            chunks_to_process,
+            model_id=emb_model,
+            embedding_dimensions=emb_model_emb_dim
+        )
+    else:
+        chart_dataframe, chart_metadata = None, None
+    return chart_dataframe, chart_metadata
+
+
+@app.cell
+def chart_dims():
+    chart_dimensions = (
+        mo.md('''
+        > **Adjust Chart Window**
+
+        {chart_height}
+
+        {chart_width}
+        ''').batch(
+            chart_height=mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
+            chart_width=mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
+        )
+    )
+    return (chart_dimensions,)
+
+
+@app.cell
+def chart_dim_values(chart_dimensions):
+    chart_height = chart_dimensions.value['chart_height']
+    chart_width = chart_dimensions.value['chart_width']
+    return chart_height, chart_width
+
+
+@app.cell
+def create_baseline_chart(
+    chart_dataframe,
+    chart_height,
+    chart_metadata,
+    chart_width,
+    create_3d_embedding_chart,
+):
+    if chart_dataframe is not None and chart_metadata is not None:
+        emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
+        chart = mo.ui.plotly(emb_plot)
+    else:
+        emb_plot = None
+        chart = None
+    return (emb_plot,)
+
+
+@app.cell
+def test_query(get_chunk_state):
+    placeholder = """How can I use watsonx.data to perform vector search?"""
+
+    query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(show_clear_button=True, submit_button_disabled=check_state(get_chunk_state()))
+    return (query,)
+
+
+@app.cell
+def query_stack(chart_dimensions, query):
+    query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
+    return (query_stack,)
+
+
+@app.function
+def check_state(variable):
+    return variable is None
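+
+
+# Usage note: check_state() gates the form submit buttons above, e.g.
+#     .form(submit_button_disabled=check_state(get_chunk_state()))
+# keeps a button disabled until the corresponding state is populated.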
+
+
+@app.cell
+def helper_function_add_query_to_chart():
+    def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
+        """
+        Add a query point to an existing 3D embedding chart as a large red dot.
+
+        Args:
+            existing_chart: The existing plotly figure or chart data
+            query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
+            query_text: Text of the query to display on hover
+            marker_size: Size of the query marker (default: 12, larger than the document markers)
+
+        Returns:
+            A modified plotly figure with the query point added as a red dot
+        """
+        import copy
+
+        import plotly.graph_objects as go
+
+        # Create a deep copy of the existing chart to avoid modifying the original
+        chart_copy = copy.deepcopy(existing_chart)
+
+        # Handle case where chart_copy is a dictionary or list (from mo.ui.plotly)
+        if isinstance(chart_copy, (dict, list)):
+            if isinstance(chart_copy, list):
+                # If it's a list, assume it's a list of traces
+                fig = go.Figure(data=chart_copy)
+            else:
+                # If it's a dict with 'data' and 'layout'
+                fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))
+
+            chart_copy = fig
+
+        # Create the query trace
+        query_trace = go.Scatter3d(
+            x=[query_coords['x']],
+            y=[query_coords['y']],
+            z=[query_coords['z']],
+            mode='markers',
+            name='Query',
+            marker=dict(
+                size=marker_size,  # Larger than the document markers so the query stands out
+                color='red',  # Bright red color
+                symbol='circle',  # Circle shape
+                opacity=0.70,  # Slightly transparent
+                line=dict(
+                    width=1,  # Thin white border
+                    color='white'
+                )
+            ),
+            # Wrap the query text at 50 characters for the hover label
+            text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])],
+            hoverinfo="text+name"
+        )
+
+        # Add the query trace to the chart copy
+        chart_copy.add_trace(query_trace)
+
+        return chart_copy
+
+    def get_query_coordinates(reference_embeddings=None, query_embedding=None):
+        """
+        Calculate appropriate coordinates for a query point based on reference embeddings.
+
+        This function handles several scenarios:
+        1. If both reference embeddings and query embedding are provided, it places the
+           query near similar documents.
+        2. If only reference embeddings are provided, it places the query at a visible
+           location near the center of the chart.
+        3. If neither are provided, it returns default origin coordinates.
+
+        Args:
+            reference_embeddings: DataFrame with x, y, z coordinates from the main chart
+            query_embedding: The embedding vector of the query
+
+        Returns:
+            Dictionary with x, y, z coordinates for the query point
+        """
+        import numpy as np
+
+        # Default coordinates (origin)
+        default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}
+
+        # If we don't have reference embeddings, return default
+        if reference_embeddings is None or len(reference_embeddings) == 0:
+            return default_coords
+
+        # If we have reference embeddings but no query embedding,
+        # position at a visible location near the center
+        if query_embedding is None:
+            center_coords = {
+                'x': reference_embeddings['x'].mean(),
+                'y': reference_embeddings['y'].mean(),
+                'z': reference_embeddings['z'].mean()
+            }
+            return center_coords
+
+        # If we have both reference embeddings and query embedding,
+        # try to position near similar documents
+        try:
+            from sklearn.metrics.pairwise import cosine_similarity
+
+            # Check if original embeddings are in the dataframe
+            if 'embedding' in reference_embeddings.columns:
+                # Get all document embeddings as a 2D array
+                if isinstance(reference_embeddings['embedding'].iloc[0], list):
+                    doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
+                else:
+                    doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])
+
+                # Reshape query embedding for comparison
+                query_emb_array = np.array(query_embedding)
+                if query_emb_array.ndim == 1:
+                    query_emb_array = query_emb_array.reshape(1, -1)
+
+                # Calculate cosine similarities
+                similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]
+
+                # Find the closest document
+                closest_idx = np.argmax(similarities)
+
+                # Use the position of the closest document, with slight offset for visibility
+                query_coords = {
+                    'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
+                    'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
+                    'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
+                }
+                return query_coords
+        except Exception as e:
+            print(f"Error positioning query near similar documents: {e}")
+
+        # Fallback to center position if similarity calculation fails
+        center_coords = {
+            'x': reference_embeddings['x'].mean(),
+            'y': reference_embeddings['y'].mean(),
+            'z': reference_embeddings['z'].mean()
+        }
+        return center_coords
+    return add_query_to_embedding_chart, get_query_coordinates
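+
+
+# Illustrative end-to-end sketch for the query overlay (assumes `chart_dataframe`,
+# `emb_plot`, and the `embedding` model from earlier cells; the query is a toy string):
+#
+#     q_emb = embedding.embed_documents(["what is vector search?"])
+#     coords = get_query_coordinates(reference_embeddings=chart_dataframe, query_embedding=q_emb)
+#     mo.ui.plotly(add_query_to_embedding_chart(emb_plot, coords, "what is vector search?"))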
+
+
+@app.cell
+def combined_chart_visualization(
+    add_query_to_embedding_chart,
+    chart_dataframe,
+    emb_plot,
+    embedding,
+    get_query_coordinates,
+    get_query_state,
+    query,
+    set_chart_state,
+    set_query_state,
+):
+    # Overlay the submitted query on the baseline 3D chart
+    if chart_dataframe is not None and query.value:
+        # Get the query embedding
+        query_emb = embedding.embed_documents([query.value])
+        set_query_state(query_emb)
+
+        # Get appropriate coordinates for the query
+        query_coords = get_query_coordinates(
+            reference_embeddings=chart_dataframe,
+            query_embedding=get_query_state()
+        )
+
+        # Add the query to the chart
+        chart_with_query = add_query_to_embedding_chart(
+            existing_chart=emb_plot,
+            query_coords=query_coords,
+            query_text=query.value,
+        )
+
+        # Create the visualization
+        combined_viz = mo.ui.plotly(chart_with_query)
+        set_chart_state(combined_viz)
+    else:
+        combined_viz = None
    return
+
+
+@app.cell
+def _():
+    get_range_slider_state, set_range_slider_state = mo.state(None)
+    return get_range_slider_state, set_range_slider_state
+
+
+@app.cell
+def _(get_range_slider_state):
+    if get_range_slider_state() is not None:
+        document_range_stack = get_range_slider_state()
+    else:
+        document_range_stack = None
+    return (document_range_stack,)
+
+
+@app.cell
+def _():
+    get_chart_state, set_chart_state = mo.state(None)
+    return get_chart_state, set_chart_state
+
+
+@app.cell
+def _(get_chart_state, query):
+    if query.value is not None:
+        chart_visualization = get_chart_state()
+    else:
+        chart_visualization = None
+    return (chart_visualization,)
+
+
+@app.cell
+def c(document_range_stack):
+    chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
+    return (chart_range_selection,)


  if __name__ == "__main__":