Maharshi Gor
Add user-input validation to the pipeline interface's error display on pipeline change.
849566b
raw
history blame
14.3 kB
import gradio as gr
import yaml
from loguru import logger
from app_configs import UNSELECTED_VAR_NAME
from components import commons
from components import typed_dicts as td
from components.model_pipeline.state_manager import (
ModelStepUIState,
PipelineState,
PipelineStateManager,
PipelineUIState,
PipelineValidator,
TossupPipelineState,
TossupPipelineStateManager,
)
from components.model_step.model_step import ModelStepComponent
from components.utils import make_state
from workflows.structs import ModelStep, TossupWorkflow, Workflow
from workflows.validators import WorkflowValidationError, WorkflowValidator
from .state_manager import get_output_panel_state
# Fallback upper bound for a model step's temperature, used when the interface
# config does not provide a "max_temperature" entry.
DEFAULT_MAX_TEMPERATURE: float = 5.0
class PipelineInterface:
    """Gradio UI for building and editing a model pipeline (workflow).

    The serialized pipeline state lives in `self.pipeline_state`. Every change to it
    recomputes the derived states `workflow_state`, `variables_state` and
    `output_panel_state`. Structural edits (add / move / remove step, YAML edits in
    the preview) also flip `self.pipeline_change`, whose `.change` event re-renders
    the list of steps.
    """

    def __init__(
        self,
        app: gr.Blocks,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
        validator: PipelineValidator | None = None,
    ):
        """Create the pipeline states, wire their listeners and render the UI.

        Rendering happens in the constructor, so this is expected to be instantiated
        inside an active `gr.Blocks` context.

        Args:
            app: The enclosing Gradio app (stored as `self.app`).
            workflow: Initial workflow. A `TossupWorkflow` selects the tossup-specific
                pipeline state and state manager.
            ui_state: Initial UI state. Derived from `workflow` via
                `PipelineUIState.from_workflow` when omitted.
            model_options: Model names offered by every step component.
            config: Optional settings. Keys read by this class: "simple" (single-step
                mode without step controls or an add-step button) and
                "max_temperature".
            validator: Passed through to the pipeline state manager.
        """
        self.app = app
        self.model_options = model_options
        # Fresh dict per instance instead of a mutable default shared across calls.
        self.config = config if config is not None else {}
        self.simple = self.config.get("simple", False)
        ui_state = ui_state or PipelineUIState.from_workflow(workflow)

        # Gradio states derived from the workflow; kept in sync by `get_aux_states` below.
        self.workflow_state = make_state(workflow.model_dump())
        self.variables_state = make_state(workflow.get_available_variables())
        self.output_panel_state = make_state(get_output_panel_state(workflow))

        # Toggled by the state manager on structural changes made through user input.
        # Its `.change` event triggers re-rendering of the steps (see `render`).
        self.pipeline_change = gr.State(False)

        if isinstance(workflow, TossupWorkflow):
            pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = TossupPipelineStateManager(validator)
        else:
            pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = PipelineStateManager(validator)
        self.pipeline_state = make_state(pipeline_state.model_dump())

        def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
            """Recompute (workflow dump, available variables, output panel state)."""
            logger.debug("Pipeline changed! Getting aux states for pipeline state.")
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            return (
                pipeline_state.workflow.model_dump(),
                pipeline_state.workflow.get_available_variables(),
                get_output_panel_state(pipeline_state.workflow),
            )

        # Keep the derived states in sync with every pipeline state change.
        self.pipeline_state.change(
            get_aux_states,
            inputs=[self.pipeline_state],
            outputs=[self.workflow_state, self.variables_state, self.output_panel_state],
            trigger_mode="multiple",
        )

        # IO variables
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        # Render the pipeline UI
        self.render()

    def _get_max_temperature(self) -> float:
        """Return the configured max temperature, or DEFAULT_MAX_TEMPERATURE if unset.

        Shared by the step components and the validator so both enforce the same bound.
        """
        return self.config.get("max_temperature", DEFAULT_MAX_TEMPERATURE)

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one model step and wire its change events into `self.pipeline_state`.

        Args:
            model_step: The step to render.
            step_ui_state: UI state of that step.
            available_variables: Variables the step may use as inputs.
            position: Index of the step in the pipeline (used by the step controls).
            n_steps: Total number of steps. Move/remove controls are only shown and
                interactive when there is more than one step.

        Returns:
            The `ModelStepComponent` alone in simple mode; otherwise a tuple of
            (component, up button, down button, remove button).
        """
        with gr.Column(elem_classes="step-container"):
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
                max_temperature=self._get_max_temperature(),
            )
            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state],
            )
            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state],
            )

            if self.simple:
                return step_interface

            is_multi_step = n_steps > 1
            # Step controls rendered below the step.
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire the (up, down, remove) buttons of the step at `position` to the state manager.

        Each handler receives (pipeline state, pipeline_change, position) and updates
        both `self.pipeline_state` and `self.pipeline_change`.
        """
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        for button, handler in (
            (up_button, self.sm.move_up),
            (down_button, self.sm.move_down),
            (remove_button, self.sm.remove_step),
        ):
            button.click(
                handler,
                inputs=[self.pipeline_state, self.pipeline_change, position_state],
                outputs=[self.pipeline_state, self.pipeline_change],
            )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button that inserts a step at `position`.

        Args:
            position: 0 renders the button in a header row; -1 renders it in a footer row.

        Raises:
            ValueError: If `position` is neither 0 nor -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
                outputs=[self.pipeline_state, self.pipeline_change],
            )
        return add_step_btn

    def validate_workflow(self, state_dict: td.PipelineStateDict):
        """Validate the workflow contained in a serialized pipeline state.

        Args:
            state_dict: Serialized pipeline state, as stored in `self.pipeline_state`.

        Raises:
            gr.Error: If the state cannot be turned into a pipeline state or the workflow
                fails `WorkflowValidator`. Raising `gr.Error` shows the message in the UI
                and prevents listeners chained with `.success(...)` from running.
        """
        try:
            state = self.sm.make_pipeline_state(state_dict)
            validator = WorkflowValidator(max_temperature=self._get_max_temperature())
            if not validator.validate(state.workflow):
                raise WorkflowValidationError(validator.errors)
        # WorkflowValidationError is listed explicitly in case it does not subclass ValueError.
        except (ValueError, WorkflowValidationError) as e:
            logger.exception(e)
            state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
            logger.error(f"Could not validate workflow: \n{state_dict_str}")
            raise gr.Error(str(e)) from e

    def _render_pipeline_header(self):
        """Render the instructions and the required input/output variable names."""
        input_variables_str = ", ".join(f"`{variable}`" for variable in self.input_variables)
        output_variables_str = ", ".join(f"`{variable}`" for variable in self.required_output_variables)
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline with the following input and output variables."
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def _render_model_steps(self, pipeline_state: PipelineState):
        """Render every step of the pipeline in UI order.

        Returns:
            One entry per step: whatever `_render_step` returned for it.
        """
        ui_state = pipeline_state.ui_state
        workflow = pipeline_state.workflow
        step_components = []
        for position, step_id in enumerate(ui_state.step_ids):
            rendered = self._render_step(
                workflow.steps[step_id],
                ui_state.steps[step_id],
                pipeline_state.get_available_variables(step_id),
                position,
                ui_state.n_steps,
            )
            step_components.append(rendered)
        return step_components

    def _output_variable_options(self, available_variables: list[str]) -> list[str]:
        """Return output-mapping choices: the unselected placeholder, then non-input variables."""
        return [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]

    def _render_output_panel(self, pipeline_state: PipelineState):
        """Render one dropdown per required output, mapping it to a pipeline variable.

        Returns:
            A dict from output field name to its `gr.Dropdown`.
        """
        dropdowns = {}
        variable_options = self._output_variable_options(pipeline_state.workflow.get_available_variables())
        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str]):
            """Refresh every dropdown's choices (same filtering as the initial render) and clear its value."""
            choices = self._output_variable_options(available_variables)
            return [gr.update(choices=choices, value=None) for _ in dropdowns]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def _render_pipeline_preview(self):
        """Render the error display and the editable YAML preview of the pipeline.

        Blurring the code box sends its text to `self.sm.update_workflow_from_code`, which
        updates the pipeline state, the `pipeline_change` toggle and the error display.
        """
        export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
        self.error_display = gr.HTML(label="Error", elem_id="pipeline-preview-error-display", visible=False)

        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )
        self.config_output.blur(
            fn=self.sm.update_workflow_from_code,
            inputs=[self.config_output, self.pipeline_change],
            outputs=[self.pipeline_state, self.pipeline_change, self.error_display],
        )
        # Export: fill the preview with the formatted config...
        self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state, scroll=True)
        # ...and open the accordion once validation succeeds.
        # NOTE(review): validate_workflow therefore runs twice per export click.
        export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[]).success(
            fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
        )

    def render(self):
        """Render the header, the dynamically re-rendered steps and output panel, and the preview."""
        self.all_components = []

        self._render_pipeline_header()

        # Re-rendered whenever `pipeline_change` flips (structural edits).
        @gr.render(
            triggers=[self.pipeline_change.change],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_steps",
            trigger_mode="multiple",
        )
        def render_steps(pipeline_state_dict: td.PipelineStateDict, evt: gr.EventData):
            """Render all steps in the pipeline"""
            logger.info(
                f"Rerender pipeline steps triggered! \nInput Pipeline's UI State:{pipeline_state_dict.get('ui_state')}\n Event: {evt.target} {evt._data}"
            )
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            self._render_model_steps(pipeline_state)

            if not self.simple:
                self._render_add_step_button(-1)

        # Re-rendered whenever the derived output panel state changes.
        @gr.render(
            triggers=[self.output_panel_state.change],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_output_fields",
            trigger_mode="multiple",
        )
        def render_output_fields(pipeline_state_dict: td.PipelineStateDict):
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            logger.debug(f"Rerendering output panel: {get_output_panel_state(pipeline_state.workflow)}")
            self._render_output_panel(pipeline_state)

        self._render_pipeline_preview()

    def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State, scroll: bool = False):
        """On any of `triggers`, validate `input_pipeline_state` and show it as YAML in the preview.

        Requires `render()` to have run (it creates `self.config_output` and
        `self.error_display`).

        Args:
            triggers: Event triggers (e.g. `button.click`) to listen on.
            input_pipeline_state: State holding the serialized pipeline to validate and export.
            scroll: If True, scroll the preview into view before formatting.
        """
        js = None
        if scroll:
            js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
        # TODO: modify this validate workflow to user input level and not executable label.
        # (workflows that can be converted to UI interface, no logical validation)
        gr.on(
            triggers,
            self.validate_workflow,
            inputs=[input_pipeline_state],
            outputs=[],
        ).success(
            # Format the same state that was just validated.
            fn=self.sm.get_formatted_config,
            inputs=[input_pipeline_state, gr.State("yaml")],
            outputs=[self.config_output, self.error_display],
            js=js,
        )