MilanM committed on
Commit 87b6e34 · verified · 1 Parent(s): f416eb7

Upload 3 files

helper_functions/debug_helper_functions.py ADDED
@@ -0,0 +1,39 @@
+ def debug_element(obj):
+     """Get all attributes and their string representations from an object using dir()."""
+     import copy
+     try:
+         # Create a deep copy of the object if possible
+         obj_copy = copy.deepcopy(obj)
+     except Exception:
+         try:
+             # If deepcopy fails, try a shallow copy
+             obj_copy = copy.copy(obj)
+         except Exception:
+             # If copying fails completely, use the original object
+             obj_copy = obj
+
+     attributes = dir(obj_copy)
+     results = []
+
+     for attr in attributes:
+         try:
+             # Get the attribute value from the copy
+             value = getattr(obj_copy, attr)
+
+             # Handle callable attributes
+             if callable(value):
+                 try:
+                     # Try to call the method without arguments
+                     result = value()
+                     str_value = f"<callable result: {str(result)}>"
+                 except Exception as call_error:
+                     # If calling fails, just record that it's a callable
+                     str_value = f"<callable: {type(value).__name__} - error when called: {str(call_error)}>"
+             else:
+                 str_value = str(value)
+
+             results.append(f"{attr}: {str_value}")
+         except Exception as e:
+             results.append(f"{attr}: <error accessing: {str(e)}>")
+
+     return results
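For orientation, a minimal usage sketch for debug_element. It assumes only that the helper_functions folder is importable as a package and that marimo is installed; the slider is an arbitrary object to introspect, not something this commit creates.

# Usage sketch (hypothetical; not part of the commit):
import marimo as mo
from helper_functions.debug_helper_functions import debug_element

slider = mo.ui.slider(start=1, stop=10)   # any Python object works as a target
for line in debug_element(slider):        # one "attr: value" string per attribute
    print(line)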
helper_functions/helper_functions.py ADDED
@@ -0,0 +1,496 @@
+ from ibm_watsonx_ai import APIClient, Credentials
+ from typing import Dict, Optional, List, Union, Any, Set
+ import pandas as pd
+ import marimo as mo
+ import json
+ import glob
+ import io
+ import os
+
+ def get_cred_value(key, creds_var_name="baked_in_creds", default=""):
+     """
+     Helper function to safely get a value from a credentials dictionary.
+
+     Searches for credentials in:
+     1. Global variables with the specified variable name
+     2. Imported modules containing the specified variable name
+
+     Args:
+         key: The key to look up in the credentials dictionary.
+         creds_var_name: The variable name of the credentials dictionary.
+         default: The default value to return if the key is not found.
+     Returns:
+         The value from the credentials dictionary if it exists and contains the key,
+         otherwise returns the default value.
+     """
+     # Check if the credentials variable exists in globals
+     if creds_var_name in globals():
+         creds_dict = globals()[creds_var_name]
+         if isinstance(creds_dict, dict) and key in creds_dict:
+             return creds_dict[key]
+
+     # Check if credentials are in an imported module
+     import sys
+     for module_name, module_obj in sys.modules.items():
+         if hasattr(module_obj, creds_var_name):
+             creds_dict = getattr(module_obj, creds_var_name)
+             if isinstance(creds_dict, dict) and key in creds_dict:
+                 return creds_dict[key]
+
+     return default
+
+ def get_key_by_value(dictionary, value):
+     for key, val in dictionary.items():
+         if val == value:
+             return key
+     return None
+
+ def markdown_spacing(number):
+     """Convert a number to that many '&nbsp;' characters."""
+     return '&nbsp;' * number
+
+ def wrap_with_spaces(text_to_wrap, prefix_spaces=2, suffix_spaces=2):
+     """Wrap text with non-breaking spaces on either side."""
+     prefix = markdown_spacing(prefix_spaces) if prefix_spaces > 0 else ""
+     suffix = markdown_spacing(suffix_spaces) if suffix_spaces > 0 else ""
+     return f"{prefix}{text_to_wrap}{suffix}"
+
+
+ def load_file_dataframe(file, file_extension, sheet_selector=None, excel_data=None, header_row=0):
+     """
+     Load a dataframe from an uploaded file with a customizable header row.
+
+     Parameters:
+     -----------
+     file : marimo.ui.file object
+         The file upload component containing the file data
+     file_extension : str
+         The extension of the uploaded file (.xlsx, .xls, .csv, .json)
+     sheet_selector : marimo.ui.dropdown, optional
+         Dropdown component for selecting Excel sheets
+     excel_data : BytesIO, optional
+         BytesIO object containing Excel data
+     header_row : int, optional
+         Row index to use as column headers (0-based). Default is 0 (first row).
+         Use None to have pandas generate default column names.
+
+     Returns:
+     --------
+     tuple
+         (pandas.DataFrame, list) - The loaded dataframe and list of column names
+     """
+
+     dataframe = pd.DataFrame([])
+     column_names = []
+
+     if file.contents():
+         # Handle different file types
+         if file_extension in ['.xlsx', '.xls'] and sheet_selector is not None and sheet_selector.value:
+             # For Excel files - now we can safely access sheet_selector.value
+             excel_data.seek(0)  # Reset buffer position
+             dataframe = pd.read_excel(
+                 excel_data,
+                 sheet_name=sheet_selector.value,
+                 header=header_row,
+                 engine="openpyxl" if file_extension == '.xlsx' else "xlrd"
+             )
+             column_names = list(dataframe.columns)
+         elif file_extension == '.csv':
+             # For CSV files
+             csv_data = io.StringIO(file.contents().decode('utf-8'))
+             dataframe = pd.read_csv(csv_data, header=header_row)
+             column_names = list(dataframe.columns)
+         elif file_extension == '.json':
+             # For JSON files
+             try:
+                 json_data = json.loads(file.contents().decode('utf-8'))
+                 # Handle different JSON structures
+                 if isinstance(json_data, list):
+                     dataframe = pd.DataFrame(json_data)
+                 elif isinstance(json_data, dict):
+                     # If it's a dictionary with nested structures, try to normalize it
+                     if any(isinstance(v, (dict, list)) for v in json_data.values()):
+                         # For nested JSON with consistent structure
+                         dataframe = pd.json_normalize(json_data)
+                     else:
+                         # For flat JSON
+                         dataframe = pd.DataFrame([json_data])
+                 column_names = list(dataframe.columns)
+             except Exception as e:
+                 print(f"Error parsing JSON: {e}")
+
+     return dataframe, column_names
+
+
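+ # Usage sketch (illustrative, not exercised by this commit): with a marimo upload
+ # component f = mo.ui.file() holding a CSV, a caller might run
+ # df, cols = load_file_dataframe(f, '.csv') and feed cols into
+ # create_parameter_table below. The names f, df, and cols are hypothetical.
+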
+ def create_parameter_table(input_list, column_name="Active Options", label="Select the Parameters to set to Active",
+                            selection_type="multi-cell", text_justify="center"):
+     """
+     Creates a marimo table for parameter selection.
+
+     Args:
+         input_list: List of parameter names to display in the table
+         column_name: Name of the column (default: "Active Options")
+         label: Label for the table (default: "Select the Parameters to set to Active")
+         selection_type: Selection type, either "single-cell" or "multi-cell" (default: "multi-cell")
+         text_justify: Text justification for the column (default: "center")
+
+     Returns:
+         A marimo table configured for parameter selection
+     """
+     import marimo as mo
+
+     # Validate selection type
+     if selection_type not in ["single-cell", "multi-cell"]:
+         raise ValueError("selection_type must be either 'single-cell' or 'multi-cell'")
+
+     # Validate text justification
+     if text_justify not in ["left", "center", "right"]:
+         raise ValueError("text_justify must be one of: 'left', 'center', 'right'")
+
+     # Create the table
+     parameter_table = mo.ui.table(
+         label=f"**{label}**",
+         data={column_name: input_list},
+         selection=selection_type,
+         text_justify_columns={column_name: text_justify}
+     )
+
+     return parameter_table
+
+ def get_cell_values(parameter_options):
+     """
+     Extract active parameter values from a mo.ui.table.
+
+     Args:
+         parameter_options: A mo.ui.table with cell selection enabled
+
+     Returns:
+         Dictionary mapping parameter names to boolean values (True/False)
+     """
+     # Get all parameter names from the table data
+     all_params = set()
+
+     # Use the data property to get all options from the table
+     if hasattr(parameter_options, 'data'):
+         table_data = parameter_options.data
+
+         # Handle DataFrame-like structure
+         if hasattr(table_data, 'shape') and hasattr(table_data, 'iloc'):
+             for i in range(table_data.shape[0]):
+                 # Get value from first column
+                 if table_data.shape[1] > 0:
+                     param = table_data.iloc[i, 0]
+                     if param and isinstance(param, str):
+                         all_params.add(param)
+
+         # Handle dict structure (common in marimo tables)
+         elif isinstance(table_data, dict):
+             # Get the first column's values
+             if len(table_data) > 0:
+                 col_name = next(iter(table_data))
+                 for param in table_data[col_name]:
+                     if param and isinstance(param, str):
+                         all_params.add(param)
+
+     # Create result dictionary with all parameters set to False by default
+     result = {param: False for param in all_params}
+
+     # Get the selected cells
+     if hasattr(parameter_options, 'value') and parameter_options.value is not None:
+         selected_cells = parameter_options.value
+
+         # Process selected cells
+         for cell in selected_cells:
+             if hasattr(cell, 'value') and cell.value in result:
+                 result[cell.value] = True
+             elif isinstance(cell, dict) and 'value' in cell and cell['value'] in result:
+                 result[cell['value']] = True
+             elif isinstance(cell, str) and cell in result:
+                 result[cell] = True
+
+     return result
+
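+ # Usage sketch (illustrative): create_parameter_table(["decoding_method", "max_new_tokens"])
+ # renders a selectable table; passing that table to get_cell_values() later returns
+ # something like {"decoding_method": True, "max_new_tokens": False}, depending on which
+ # cells the user selected. The parameter names shown here are placeholders.
+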
+ def convert_table_to_json_docs(df, selected_columns=None):
+     """
+     Convert a pandas DataFrame or dictionary to a list of JSON documents.
+     Dynamically includes columns based on user selection.
+     Column names are standardized to lowercase with underscores instead of spaces
+     and special characters removed.
+
+     Args:
+         df: The DataFrame or dictionary to process
+         selected_columns: List of column names to include in the output documents
+
+     Returns:
+         list: A list of dictionaries, each representing a row as a JSON document
+     """
+     import pandas as pd
+     import re
+
+     def standardize_key(key):
+         """Convert a column name to lowercase with underscores instead of spaces and no special characters"""
+         if not isinstance(key, str):
+             return str(key).lower()
+         # Replace spaces with underscores and convert to lowercase
+         key = key.lower().replace(' ', '_')
+         # Remove special characters (keeping alphanumeric and underscores)
+         return re.sub(r'[^\w]', '', key)
+
+     # Handle case when input is a dictionary
+     if isinstance(df, dict):
+         # Filter the dictionary to include only selected columns
+         if selected_columns:
+             return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
+         else:
+             # If no columns selected, return all key-value pairs with standardized keys
+             return [{standardize_key(k): v for k, v in df.items()}]
+
+     # Handle case when df is None
+     if df is None:
+         return []
+
+     # Ensure df is a DataFrame
+     if not isinstance(df, pd.DataFrame):
+         try:
+             df = pd.DataFrame(df)
+         except Exception:
+             return []  # Return empty list if conversion fails
+
+     # Now check if DataFrame is empty
+     if df.empty:
+         return []
+
+     # Process selected_columns if it's a dictionary of true/false values
+     if isinstance(selected_columns, dict):
+         # Extract keys where value is True
+         selected_columns = [col for col, include in selected_columns.items() if include]
+
+     # If no columns are specifically selected, use all available columns
+     if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
+         selected_columns = list(df.columns)
+
+     # Determine which columns exist in the DataFrame
+     available_columns = []
+     columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
+
+     for col in selected_columns:
+         if col in df.columns:
+             available_columns.append(col)
+         elif isinstance(col, str) and col.lower() in columns_lower:
+             available_columns.append(columns_lower[col.lower()])
+
+     # If no valid columns found, return empty list
+     if not available_columns:
+         return []
+
+     # Process rows
+     json_docs = []
+     for _, row in df.iterrows():
+         doc = {}
+         for col in available_columns:
+             value = row[col]
+             # Standardize the column name when adding to document
+             std_col = standardize_key(col)
+             doc[std_col] = None if pd.isna(value) else value
+         json_docs.append(doc)
+
+     return json_docs
+
+ def filter_models_by_function(resources, function_type="prompt_chat"):
+     """
+     Filter model IDs from resources list that have a specific function type
+
+     Args:
+         resources (list): List of model resource objects
+         function_type (str, optional): Function type to filter by. Defaults to "prompt_chat".
+
+     Returns:
+         list: List of model IDs that have the specified function
+     """
+     filtered_model_ids = []
+
+     if not resources or not isinstance(resources, list):
+         return filtered_model_ids
+
+     for model in resources:
+         # Check if the model entry has a functions list
+         if "functions" in model and isinstance(model["functions"], list):
+             # Check if any function has the matching id
+             has_function = any(
+                 func.get("id") == function_type
+                 for func in model["functions"]
+                 if isinstance(func, dict)
+             )
+
+             if has_function and "model_id" in model:
+                 filtered_model_ids.append(model["model_id"])
+
+     return filtered_model_ids
+
+
+ def get_model_selection_table(client=None, model_type="all", filter_functionality=None, selection_mode="single-cell"):
+     """
+     Creates and displays a table for model selection based on specified parameters.
+
+     Args:
+         client: The client object for API calls. If None, returns default models.
+         model_type (str): Type of models to display. Options: "all", "chat", "embedding".
+         filter_functionality (str, optional): Filter models by functionality type.
+             Options include: "image_chat", "text_chat", "autoai_rag",
+             "text_generation", "multilingual", etc.
+         selection_mode (str): Mode for selecting table entries. Options: "single", "single-cell".
+             Defaults to "single-cell".
+
+     Returns:
+         tuple: (model_selector table, resources list, model_id list); if client is None, only a default selection table is returned.
+     """
+     # Default model list if client is None
+     default_models = ['mistralai/mistral-large']
+
+     if client is None:
+         # If no client, use default models
+         available_models = default_models
+         selection = mo.ui.table(
+             available_models,
+             selection="single",
+             label="Select a model to use.",
+             page_size=30,
+         )
+         return selection
+
+     # Get appropriate model specs based on model_type
+     if model_type == "chat":
+         model_specs = client.foundation_models.get_chat_model_specs()
+     elif model_type == "embedding":
+         model_specs = client.foundation_models.get_embeddings_model_specs()
+     else:
+         model_specs = client.foundation_models.get_model_specs()
+
+     # Extract resources from model specs
+     resources = model_specs.get("resources", [])
+
+     # Filter by functionality if specified
+     if filter_functionality and resources:
+         model_id_list = filter_models_by_function(resources, filter_functionality)
+     else:
+         # Create list of model IDs if no filtering
+         model_id_list = [resource["model_id"] for resource in resources]
+
+     # If no models available after filtering, use defaults
+     if not model_id_list:
+         model_id_list = default_models
+
+     # Create and display selection table
+     model_selector = mo.ui.table(
+         model_id_list,
+         selection=selection_mode,
+         label="Select a model to use.",
+         page_size=30,
+         initial_selection=[("0", "value")] if selection_mode == "single-cell" else [0]
+         ### For single-cell it must have [("<row_nr as a string>","column_name string")] to work as initial value
+     )
+
+     return model_selector, resources, model_id_list
+
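+ # Usage sketch (illustrative): with an authenticated APIClient named client, a caller might run
+ # selector, resources, model_ids = get_model_selection_table(client, model_type="chat")
+ # and pass all three values to update_max_tokens_limit() below, e.g. to cap a max-tokens
+ # slider. The client object is assumed to be created elsewhere in the notebook.
+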
+ def _enforce_model_selection(model_selection, model_id_list):
+     # If nothing is selected (empty list) or value is None
+     if not model_selection.value:
+         # Reset to first item
+         model = 0
+         model_selection._value = model_id_list[model]
+         print(model_selection.value)
+     return model_selection.value
+
+ def update_max_tokens_limit(model_selection, resources, model_id_list):
+     # Default value
+     default_max_tokens = 4096
+
+     try:
+         # Check if we have a selection and resources
+         if not hasattr(model_selection, 'value') or model_selection.value is None:
+             print("No model selection or selection has no value")
+             return default_max_tokens
+
+         if not resources or not isinstance(resources, list) or len(resources) == 0:
+             print("Resources is empty or not a list")
+             return default_max_tokens
+
+         # Get the model ID - handle both index selection and direct string selection
+         selected_value = model_selection.value
+         print(f"Raw selection value: {selected_value}")
+
+         # If it's an array with indices
+         if isinstance(selected_value, list) and len(selected_value) > 0:
+             if isinstance(selected_value[0], int) and 0 <= selected_value[0] < len(model_id_list):
+                 selected_model_id = model_id_list[selected_value[0]]
+             else:
+                 selected_model_id = str(selected_value[0])  # Convert to string if needed
+         else:
+             selected_model_id = str(selected_value)  # Direct value
+
+         print(f"Selected model ID: {selected_model_id}")
+
+         # Find the model
+         for model in resources:
+             model_id = model.get("model_id")
+             if model_id == selected_model_id:
+                 if "model_limits" in model and "max_output_tokens" in model["model_limits"]:
+                     return model["model_limits"]["max_output_tokens"]
+                 break
+
+     except Exception as e:
+         print(f"Error: {e}")
+
+     return default_max_tokens
+
+
+ def load_templates(
+     folder_path: str,
+     file_extensions: Optional[List[str]] = None,
+     strip_whitespace: bool = True
+ ) -> Dict[str, str]:
+     """
+     Load template files from a specified folder into a dictionary.
+
+     Args:
+         folder_path: Path to the folder containing template files
+         file_extensions: List of file extensions to include (default: ['.txt', '.md'])
+         strip_whitespace: Whether to strip leading/trailing whitespace from templates (default: True)
+
+     Returns:
+         Dictionary with filename (without extension) as key and file content as value
+     """
+     # Default extensions if none provided
+     if file_extensions is None:
+         file_extensions = ['.txt', '.md']
+
+     # Ensure extensions start with a dot
+     file_extensions = [ext if ext.startswith('.') else f'.{ext}' for ext in file_extensions]
+
+     templates = {"empty": " "}  # Default empty template
+
+     # Create glob patterns for each extension
+     patterns = [os.path.join(folder_path, f'*{ext}') for ext in file_extensions]
+
+     # Find all matching files
+     for pattern in patterns:
+         for file_path in glob.glob(pattern):
+             try:
+                 # Extract filename without extension to use as key
+                 filename = os.path.basename(file_path)
+                 template_name = os.path.splitext(filename)[0]
+
+                 # Read file content
+                 with open(file_path, 'r', encoding='utf-8') as file:
+                     content = file.read()
+
+                 # Strip whitespace if specified
+                 if strip_whitespace:
+                     content = content.strip()
+
+                 templates[template_name] = content
+
+             except Exception as e:
+                 print(f"Error loading template from {file_path}: {str(e)}")
+
+     return templates
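Taken together, the credential, client, and template helpers above are presumably wired up along these lines in the notebook. A hedged sketch: the baked_in_creds key names ("url", "api_key", "project_id"), the default endpoint, and the prompt_templates folder are assumptions, not something this commit defines.

# Hypothetical wiring of the credential, client, and template helpers.
from ibm_watsonx_ai import APIClient, Credentials
from helper_functions.helper_functions import get_cred_value, load_templates, get_model_selection_table

credentials = Credentials(
    url=get_cred_value("url", default="https://us-south.ml.cloud.ibm.com"),
    api_key=get_cred_value("api_key"),          # key names are assumed
)
client = APIClient(credentials, project_id=get_cred_value("project_id"))

templates = load_templates("prompt_templates")  # folder name is an assumption
selector, resources, model_ids = get_model_selection_table(client, model_type="chat")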
helper_functions/table_helper_functions.py ADDED
@@ -0,0 +1,255 @@
+
+ def process_with_llm(fields_to_process, prompt_template, inf_model, params, batch_size=10):
+     """
+     Process documents with LLM using a prompt template with dynamic field mapping.
+     Uses template fields to extract values from pre-standardized document fields.
+
+     Args:
+         fields_to_process (list): List of document dictionaries to process
+         prompt_template (str): Template with {field_name} placeholders matching keys in documents
+         inf_model: The inference model instance to use for generation
+         params: Parameters to pass to the inference model
+         batch_size (int): Number of documents to process per batch
+
+     Returns:
+         list: Processed results from the LLM
+     """
+     import time
+     import re
+
+     # Safety check for inputs
+     if not fields_to_process or not inf_model:
+         print("Missing required inputs")
+         return []
+
+     # Handle case where prompt_template is a dictionary (from UI components)
+     if isinstance(prompt_template, dict) and 'value' in prompt_template:
+         prompt_template = prompt_template['value']
+     elif not isinstance(prompt_template, str):
+         print(f"Invalid prompt template type: {type(prompt_template)}, expected string")
+         return []
+
+     # Extract field names from the prompt template using regex
+     # This finds all strings between curly braces
+     field_pattern = r'\{([^{}]+)\}'
+     template_fields = re.findall(field_pattern, prompt_template)
+
+     if not template_fields:
+         print("No field placeholders found in template")
+         return []
+
+     # Create formatted prompts from the documents
+     formatted_prompts = []
+     for doc in fields_to_process:
+         try:
+             # Create a dictionary of field values to substitute
+             field_values = {}
+
+             for field in template_fields:
+                 # Try direct match first
+                 if field in doc:
+                     field_values[field] = doc[field] if doc[field] is not None else ""
+                 # If field contains periods (e.g., "data.title"), evaluate it
+                 elif '.' in field:
+                     try:
+                         # Walk the nested dictionary one path segment at a time
+                         parts = field.split('.')
+                         value = doc
+                         for part in parts:
+                             if isinstance(value, dict) and part in value:
+                                 value = value[part]
+                             else:
+                                 value = None
+                                 break
+                         field_values[field] = value if value is not None else ""
+                     except Exception:
+                         field_values[field] = ""
+                 else:
+                     # Default to empty string if field not found
+                     field_values[field] = ""
+
+             # Handle None values at the top level to ensure formatting works
+             for key in field_values:
+                 if field_values[key] is None:
+                     field_values[key] = ""
+
+             # Format the prompt with all available fields
+             prompt = prompt_template.format(**field_values)
+             formatted_prompts.append(prompt)
+
+         except Exception as e:
+             print(f"Error formatting prompt: {str(e)}")
+             print(f"Field values: {field_values}")
+             continue
+
+     # Return empty list if no valid prompts
+     if not formatted_prompts:
+         print("No valid prompts generated")
+         return []
+
+     # Print a sample of the formatted prompts for debugging
+     if formatted_prompts:
+         print(f"Sample formatted prompt: {formatted_prompts[0][:200]}...")
+
+     # Split into batches
+     batches = [formatted_prompts[i:i + batch_size] for i in range(0, len(formatted_prompts), batch_size)]
+
+     results = []
+
+     # Process each batch
+     for i, batch in enumerate(batches):
+         start_time = time.time()
+
+         try:
+             # Use the provided inference model to generate responses
+             print(f"Sending batch {i+1} of {len(batches)} to model")
+
+             # Call the inference model with the batch of prompts and params
+             batch_results = inf_model.generate_text(prompt=batch, params=params)
+
+             results.extend(batch_results)
+
+         except Exception as e:
+             print(f"Error in batch {i+1}: {str(e)}")
+             continue
+
+         end_time = time.time()
+         inference_time = end_time - start_time
+         print(f"Inference time for Batch {i+1}: {inference_time:.2f} seconds")
+
+     return results
+
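+ # Usage sketch (illustrative): given docs from convert_table_to_json_docs() and a
+ # watsonx.ai ModelInference instance inf_model, a caller might run
+ # answers = process_with_llm(docs, "Summarize: {text}", inf_model, params={"max_new_tokens": 200})
+ # where the "text" placeholder and the params values are placeholders, not fixed names.
+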
+ def append_llm_results_to_dataframe(target_dataframe, fields_to_process, llm_results, selection_table, column_name=None):
+     """
+     Add LLM processing results directly to the target DataFrame using selection indices
+
+     Args:
+         target_dataframe (pandas.DataFrame): DataFrame to modify in-place
+         fields_to_process (list): List of document dictionaries that were processed
+         llm_results (list): Results from the process_with_llm function
+         selection_table: Table selection containing indices of rows to update
+         column_name (str, optional): Custom name for the new column
+     """
+     column_name = column_name or f"Added Column {len(list(target_dataframe))}"
+
+     # Initialize the new column with empty strings if it doesn't exist
+     if column_name not in target_dataframe.columns:
+         target_dataframe[column_name] = ""
+
+     # Safety checks
+     if not isinstance(llm_results, list) or not llm_results:
+         print("No LLM results to add")
+         return
+
+     # Get indices from selection table
+     if selection_table is not None and not selection_table.empty:
+         selected_indices = selection_table.index.tolist()
+
+         # Make sure we have the right number of results for the selected rows
+         if len(selected_indices) != len(llm_results):
+             print(f"Warning: Number of results ({len(llm_results)}) doesn't match selected rows ({len(selected_indices)})")
+
+         # Add results to the DataFrame at the selected indices
+         for idx, result in zip(selected_indices, llm_results):
+             try:
+                 if idx < len(target_dataframe):
+                     target_dataframe.at[idx, column_name] = result
+                 else:
+                     print(f"Warning: Selected index {idx} exceeds DataFrame length")
+             except Exception as e:
+                 print(f"Error adding result to DataFrame: {str(e)}")
+     else:
+         print("No selection table provided or empty selection")
+
+ def add_llm_results_to_dataframe(original_df, fields_to_process, llm_results, column_name=None):
+     """
+     Add LLM processing results to a copy of the original DataFrame
+
+     Args:
+         original_df (pandas.DataFrame): Original DataFrame
+         fields_to_process (list): List of document dictionaries that were processed
+         llm_results (list): Results from the process_with_llm function
+
+     Returns:
+         pandas.DataFrame: Copy of the original DataFrame with results added in a new column (default "Added Column {n}", or column_name if provided)
+     """
+     import pandas as pd
+
+     column_name = column_name or f"Added Column {len(list(original_df))}"
+
+     # Create a copy of the original DataFrame
+     result_df = original_df.copy()
+
+     # Initialize the new column with empty strings
+     result_df[column_name] = ""
+
+     # Safety checks
+     if not isinstance(llm_results, list) or not llm_results:
+         print("No LLM results to add")
+         return result_df
+
+     # Add results to the DataFrame
+     for i, (doc, result) in enumerate(zip(fields_to_process, llm_results)):
+         try:
+             # Find the matching row in the DataFrame
+             # This assumes the order of fields_to_process matches the original DataFrame
+             if i < len(result_df):
+                 result_df.at[i, column_name] = result
+             else:
+                 print(f"Warning: Result index {i} exceeds DataFrame length")
+         except Exception as e:
+             print(f"Error adding result to DataFrame: {str(e)}")
+             continue
+
+     return result_df
+
+
+ def display_answers_as_markdown(answers, mo):
+     """
+     Takes a list of answers and displays each one as markdown using mo.md()
+
+     Args:
+         answers (list): List of text answers from the LLM
+         mo: The existing marimo module from the environment
+
+     Returns:
+         list: List of markdown elements
+     """
+     # Handle case where answers is None or empty
+     if not answers:
+         return [mo.md("No answers available")]
+
+     # Create markdown for each answer
+     markdown_elements = []
+     for i, answer in enumerate(answers):
+         # Create a formatted markdown element with answer number and content
+         md_element = mo.md(f"""\n\n---\n\n# Answer {i+1}\n\n{answer}""")
+         markdown_elements.append(md_element)
+
+     return markdown_elements
+
+ def display_answers_stacked(answers, mo):
+     """
+     Takes a list of answers and displays them stacked vertically using mo.vstack()
+
+     Args:
+         answers (list): List of text answers from the LLM
+         mo: The existing marimo module from the environment
+
+     Returns:
+         element: A vertically stacked collection of markdown elements
+     """
+     # Get individual markdown elements
+     md_elements = display_answers_as_markdown(answers, mo)
+
+     # Add separator between each answer
+     separator = mo.md("---")
+     elements_with_separators = []
+
+     for i, elem in enumerate(md_elements):
+         elements_with_separators.append(elem)
+         if i < len(md_elements) - 1:  # Don't add separator after the last element
+             elements_with_separators.append(separator)
+
+     # Return a vertically stacked collection
+     return mo.vstack(elements_with_separators, align="start", gap="2")
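For context, a plausible end-to-end flow through these table helpers, using toy data. This is a sketch, not the author's pipeline: the EchoModel stand-in, the column names, and the prompt are illustrative assumptions, and in practice inf_model would be an ibm_watsonx_ai ModelInference instance.

# Hypothetical end-to-end flow; every name below is a stand-in for the sketch.
import pandas as pd
import marimo as mo
from helper_functions.helper_functions import convert_table_to_json_docs
from helper_functions.table_helper_functions import (
    process_with_llm,
    add_llm_results_to_dataframe,
    display_answers_stacked,
)

class EchoModel:
    """Stand-in for a ModelInference object (assumption made for this sketch)."""
    def generate_text(self, prompt, params=None):
        # Return one fake completion per prompt in the batch
        return [f"[echo] {p[:40]}" for p in prompt]

df = pd.DataFrame({"Ticket Text": ["printer jammed", "vpn drops hourly"]})   # toy data
docs = convert_table_to_json_docs(df, selected_columns=["Ticket Text"])      # keys become "ticket_text"
inf_model, params = EchoModel(), {"max_new_tokens": 200}                     # assumed generation params
answers = process_with_llm(docs, "Summarize this ticket: {ticket_text}", inf_model, params)
result_df = add_llm_results_to_dataframe(df, docs, answers, column_name="LLM Summary")
stacked = display_answers_stacked(answers, mo)                               # marimo element for the notebook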