def process_with_llm(fields_to_process, prompt_template, inf_model, params, batch_size=10):
    """
    Process documents with an LLM using a prompt template with dynamic field mapping.

    Uses template fields to extract values from pre-standardized document fields.

    Args:
        fields_to_process (list): List of document dictionaries to process
        prompt_template (str): Template with {field_name} placeholders matching keys in the documents
        inf_model: The inference model instance to use for generation
        params (dict): Parameters to pass to the inference model
        batch_size (int): Number of documents to process per batch

    Returns:
        list: Processed results from the LLM
    """
    import marimo as mo
    import time
    import re

    if not fields_to_process or not inf_model:
        print("Missing required inputs")
        return []

    # Some UI widgets hand the template over as a {'value': ...} dict; unwrap it.
    if isinstance(prompt_template, dict) and 'value' in prompt_template:
        prompt_template = prompt_template['value']
    elif not isinstance(prompt_template, str):
        print(f"Invalid prompt template type: {type(prompt_template)}, expected string")
        return []

    # Collect the {field_name} placeholders referenced by the template.
    field_pattern = r'\{([^{}]+)\}'
    template_fields = re.findall(field_pattern, prompt_template)

    if not template_fields:
        print("No field placeholders found in template")
        return []

    formatted_prompts = []
    for doc in fields_to_process:
        try:
            field_values = {}
            for field in template_fields:
                if field in doc:
                    field_values[field] = doc[field] if doc[field] is not None else ""
                elif '.' in field:
                    # Dotted placeholders such as {meta.author} walk nested dicts.
                    try:
                        value = doc
                        for part in field.split('.'):
                            if isinstance(value, dict) and part in value:
                                value = value[part]
                            else:
                                value = None
                                break
                        field_values[field] = value if value is not None else ""
                    except Exception:
                        field_values[field] = ""
                else:
                    field_values[field] = ""

            # Substitute placeholders directly rather than via str.format, which
            # would misread a dotted name like {meta.author} as attribute access
            # and raise KeyError instead of using the collected value.
            prompt = re.sub(
                field_pattern,
                lambda m: str(field_values.get(m.group(1), "")),
                prompt_template,
            )
            formatted_prompts.append(prompt)
        except Exception as e:
            print(f"Error formatting prompt: {str(e)}")
            print(f"Field values: {field_values}")
            continue

    if not formatted_prompts:
        print("No valid prompts generated")
        return []

    print(f"Sample formatted prompt: {formatted_prompts[0][:200]}...")

    # Chunk the prompts so each model call carries at most batch_size of them.
    batches = [formatted_prompts[i:i + batch_size] for i in range(0, len(formatted_prompts), batch_size)]
    results = []

    with mo.status.progress_bar(
        total=len(batches),
        title="Processing Batches",
        subtitle=f"Processing {len(formatted_prompts)} prompts in {len(batches)} batches",
        completion_title="Processing Complete",
        completion_subtitle=f"Processed {len(formatted_prompts)} prompts successfully",
        show_rate=True,
        show_eta=True,
        remove_on_exit=True,
    ) as progress:
        for i, batch in enumerate(batches):
            start_time = time.time()

            try:
                print(f"Sending batch {i + 1} of {len(batches)} to model")
                batch_results = inf_model.generate_text(prompt=batch, params=params)
                results.extend(batch_results)
            except Exception as e:
                print(f"Error in batch {i + 1}: {str(e)}")
                continue

            inference_time = time.time() - start_time
            print(f"Inference time for batch {i + 1}: {inference_time:.2f} seconds")

            progress.update(increment=1)

            # Brief pause between batches to avoid hammering the endpoint.
            time.sleep(1)

    return results
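
# A minimal usage sketch for process_with_llm. The document fields, template,
# and generation params below are hypothetical; `inf_model` is assumed to be
# any client whose generate_text(prompt=[...], params=...) accepts a list of
# prompts and returns a list of completions, which is the interface the
# function above relies on.
def example_summarize_documents(inf_model):
    docs = [
        {"title": "Q3 report", "body": "Revenue grew 12% quarter over quarter."},
        {"title": "Q4 outlook", "body": "Guidance assumes stable demand."},
    ]
    template = "Summarize the document titled {title}:\n\n{body}"
    return process_with_llm(docs, template, inf_model,
                            params={"max_new_tokens": 200}, batch_size=2)
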
def append_llm_results_to_dataframe(target_dataframe, fields_to_process, llm_results, selection_table, column_name=None):
    """
    Add LLM processing results directly to the target DataFrame using selection indices.

    Args:
        target_dataframe (pandas.DataFrame): DataFrame to modify in-place
        fields_to_process (list): List of document dictionaries that were processed
        llm_results (list): Results from the process_with_llm function
        selection_table: Table selection containing indices of rows to update
        column_name (str, optional): Custom name for the new column
    """
    column_name = column_name or f"Added Column {len(target_dataframe.columns)}"

    if column_name not in target_dataframe.columns:
        target_dataframe[column_name] = ""

    if not isinstance(llm_results, list) or not llm_results:
        print("No LLM results to add")
        return

    if selection_table is not None and not selection_table.empty:
        # Row alignment comes from the selection's index labels, not from
        # fields_to_process, so results are written in selection order.
        selected_indices = selection_table.index.tolist()

        if len(selected_indices) != len(llm_results):
            print(f"Warning: Number of results ({len(llm_results)}) doesn't match selected rows ({len(selected_indices)})")

        for idx, result in zip(selected_indices, llm_results):
            try:
                # Check membership in the index rather than comparing against
                # len(): the selection carries index labels, which only coincide
                # with positions on a default RangeIndex.
                if idx in target_dataframe.index:
                    target_dataframe.at[idx, column_name] = result
                else:
                    print(f"Warning: Selected index {idx} is not in the DataFrame index")
            except Exception as e:
                print(f"Error adding result to DataFrame: {str(e)}")
    else:
        print("No selection table provided or empty selection")
def add_llm_results_to_dataframe(original_df, fields_to_process, llm_results, column_name=None):
    """
    Add LLM processing results to a copy of the original DataFrame.

    Args:
        original_df (pandas.DataFrame): Original DataFrame
        fields_to_process (list): List of document dictionaries that were processed
        llm_results (list): Results from the process_with_llm function
        column_name (str, optional): Custom name for the new column

    Returns:
        pandas.DataFrame: Copy of the original DataFrame with the results in a
        new column, named "Added Column {column count}" unless a custom name is given
    """
    column_name = column_name or f"Added Column {len(original_df.columns)}"

    result_df = original_df.copy()
    result_df[column_name] = ""

    if not isinstance(llm_results, list) or not llm_results:
        print("No LLM results to add")
        return result_df

    # Results are assumed to be in the same order as fields_to_process, which
    # in turn mirrors the row order of the DataFrame.
    for i, (doc, result) in enumerate(zip(fields_to_process, llm_results)):
        try:
            if i < len(result_df):
                result_df.at[i, column_name] = result
            else:
                print(f"Warning: Result index {i} exceeds DataFrame length")
        except Exception as e:
            print(f"Error adding result to DataFrame: {str(e)}")
            continue

    return result_df
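
# Round-trip sketch for the non-mutating variant, with made-up data: the copy
# gains the new column while the original frame stays untouched.
def example_add_results():
    import pandas as pd

    df = pd.DataFrame({"title": ["Q3 report", "Q4 outlook"]})
    out = add_llm_results_to_dataframe(
        df,
        fields_to_process=df.to_dict("records"),
        llm_results=["up", "flat"],
        column_name="Trend",
    )
    assert "Trend" not in df.columns
    assert list(out["Trend"]) == ["up", "flat"]
    return out
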
def display_answers_as_markdown(answers, mo):
    """
    Takes a list of answers and displays each one as markdown using mo.md()

    Args:
        answers (list): List of text answers from the LLM
        mo: The existing marimo module from the environment

    Returns:
        list: List of markdown elements
    """
    if not answers:
        return [mo.md("No answers available")]

    markdown_elements = []
    for i, answer in enumerate(answers):
        md_element = mo.md(f"""\n\n---\n\n# Answer {i+1}\n\n{answer}""")
        markdown_elements.append(md_element)

    return markdown_elements
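
# Usage sketch: placeholder answers rendered as individual markdown blocks. In
# a marimo notebook the returned list displays when it is a cell's output, or
# it can be stacked with the helper below.
def example_markdown_answers(mo):
    answers = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
    return display_answers_as_markdown(answers, mo)
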
def display_answers_stacked(answers, mo):
    """
    Takes a list of answers and displays them stacked vertically using mo.vstack()

    Args:
        answers (list): List of text answers from the LLM
        mo: The existing marimo module from the environment

    Returns:
        element: A vertically stacked collection of markdown elements
    """
    md_elements = display_answers_as_markdown(answers, mo)

    # Interleave an explicit rule between consecutive answers.
    separator = mo.md("---")
    elements_with_separators = []

    for i, elem in enumerate(md_elements):
        elements_with_separators.append(elem)
        if i < len(md_elements) - 1:
            elements_with_separators.append(separator)

    # mo.vstack documents gap as a number, so pass 2 rather than the string "2".
    return mo.vstack(elements_with_separators, align="start", gap=2)
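
# End-to-end sketch tying the helpers together: build prompts from documents,
# run them through the model, and stack the answers in a single cell output.
# `inf_model` and `params` are stand-ins for your own inference client setup.
def example_answer_pipeline(inf_model, params, mo):
    docs = [{"question": "What is retrieval-augmented generation?"}]
    answers = process_with_llm(docs, "Answer concisely: {question}", inf_model, params)
    return display_answers_stacked(answers, mo)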