# Maharshi Gor
# Updated workflow APIs, code clean up and minor functions for hf pipeline support
# f064c62
import gradio as gr | |
import yaml | |
from loguru import logger | |
from app_configs import UNSELECTED_VAR_NAME | |
from components import commons | |
from components import typed_dicts as td | |
from components.model_pipeline.state_manager import ( | |
BasePipelineValidator, | |
PipelineStateManager, | |
) | |
from components.model_step.model_step import ModelStepComponent | |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState | |
from components.utils import make_state | |
from shared.workflows.structs import ModelStep, Workflow | |
from .state_manager import get_output_panel_state | |
# Fallback passed as ``max_temperature`` to each ModelStepComponent when the
# pipeline config has no "max_temperature" entry.
DEFAULT_MAX_TEMPERATURE: float = 5.0
class PipelineInterface:
    """Gradio UI for editing a multi-step model pipeline.

    The authoritative state lives in ``self.pipeline_state`` (a Gradio State
    holding a pipeline-state dict). Auxiliary states (serialized workflow,
    available variables, output-panel state) are derived from it on every
    change. All mutations are delegated to ``self.sm``, an instance of
    ``state_manager_cls``.
    """

    # Subclasses may override this to plug in a different state manager.
    state_manager_cls = PipelineStateManager

    def __init__(
        self,
        app: gr.Blocks,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
        validator: BasePipelineValidator | None = None,
    ):
        """Create the Gradio states for ``workflow`` and render the pipeline UI.

        Args:
            app: The enclosing Blocks app; its ``load`` event is used as a render trigger.
            workflow: Initial workflow to display.
            ui_state: Initial UI state; derived from ``workflow`` when not given.
            model_options: Model choices forwarded to every step component.
            config: Optional UI settings. Keys read by this class: ``"simple"``
                (default False) and ``"max_temperature"`` (default
                ``DEFAULT_MAX_TEMPERATURE``).
            validator: Optional validator; passed to the state manager and called
                by ``validate_workflow_ui``.
        """
        self.app = app
        self.model_options = model_options
        # Fresh dict per instance instead of a shared mutable default argument.
        self.config = {} if config is None else config
        self.simple = self.config.get("simple", False)
        ui_state = ui_state or PipelineUIState.from_workflow(workflow)

        # Gradio States
        self.workflow_state = make_state(workflow.model_dump())
        self.variables_state = make_state(workflow.get_available_variables())
        self.output_panel_state = make_state(get_output_panel_state(workflow))
        self.ui_validator = validator

        # Toggled by structural edits (add/move/remove step, YAML edit); used as a re-render trigger.
        self.pipeline_change = gr.State(False)

        self.sm = self.state_manager_cls(validator)
        pipeline_state_dict = self.sm.create_pipeline_state_dict(workflow=workflow, ui_state=ui_state)
        self.pipeline_state = make_state(pipeline_state_dict)

        def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
            """Derive (workflow dump, available variables, output-panel state) from the pipeline state."""
            logger.debug("Pipeline changed! Getting aux states for pipeline state.")
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            return (
                pipeline_state.workflow.model_dump(),
                pipeline_state.workflow.get_available_variables(),
                get_output_panel_state(pipeline_state.workflow),
            )

        # Keep the auxiliary states in sync with every pipeline state change.
        self.pipeline_state.change(
            get_aux_states,
            inputs=[self.pipeline_state],
            outputs=[self.workflow_state, self.variables_state, self.output_panel_state],
            trigger_mode="multiple",
        )

        # IO Variables
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        # Render the pipeline UI
        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one model step and wire its change events into ``pipeline_state``.

        Args:
            model_step: The step to render.
            step_ui_state: UI state for this step.
            available_variables: Variables offered to the step as inputs.
            position: Index of the step in UI order; bound to the move/remove controls.
            n_steps: Total number of steps; the controls are hidden and disabled
                when there is only one step.

        Returns:
            The ``ModelStepComponent`` alone in simple mode, otherwise the tuple
            ``(step_component, up_button, down_button, remove_button)``.
        """
        with gr.Column(elem_classes="step-container"):
            # Create the step component
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
                max_temperature=self.config.get("max_temperature", DEFAULT_MAX_TEMPERATURE),
            )

            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state],
            )

            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state],
            )

            if self.simple:
                return step_interface

            is_multi_step = n_steps > 1
            # Add step controls below
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Bind the up/down/remove buttons of the step at ``position`` to the state manager."""
        up_button, down_button, remove_button = buttons
        # Wrap the index in a State so it can be passed as an event input.
        position_state = gr.State(position)
        handlers = (
            (up_button, self.sm.move_up),
            (down_button, self.sm.move_down),
            (remove_button, self.sm.remove_step),
        )
        for button, handler in handlers:
            button.click(
                handler,
                inputs=[self.pipeline_state, self.pipeline_change, position_state],
                outputs=[self.pipeline_state, self.pipeline_change],
            )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button in the header (``position == 0``) or footer (``position == -1``).

        ``position`` is forwarded to ``self.sm.add_step``.

        Raises:
            ValueError: If ``position`` is neither 0 nor -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
                outputs=[self.pipeline_state, self.pipeline_change],
            )
            return add_step_btn

    def validate_workflow_ui(self, state_dict: td.PipelineStateDict):
        """Validate the workflow in ``state_dict`` and return an update for the error display.

        Returns ``gr.update(visible=False)`` when validation passes or no validator
        is configured. When building the state or validating raises ``ValueError``,
        it returns a visible update carrying the error message.
        """
        try:
            state = self.sm.make_pipeline_state(state_dict)
            # ``validator`` is optional in __init__; calling None would raise TypeError.
            if self.ui_validator is not None:
                self.ui_validator(state.workflow)
        except ValueError as e:
            logger.exception(e)
            return gr.update(visible=True, value=str(e))
        return gr.update(visible=False)

    def _render_pipeline_header(self):
        """Render the instructions that list the pipeline's input and output variables."""
        input_variables_str = ", ".join(f"`{variable}`" for variable in self.input_variables)
        output_variables_str = ", ".join(f"`{variable}`" for variable in self.required_output_variables)
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline with the following input and output variables."
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def _render_model_steps(self, pipeline_state: PipelineState):
        """Render every step in UI order and return the per-step results of ``_render_step``.

        Each element is either the step component (simple mode) or the
        ``(step_component, up, down, remove)`` tuple.
        """
        ui_state = pipeline_state.ui_state
        workflow = pipeline_state.workflow
        step_objects = []
        for i, step_id in enumerate(ui_state.step_ids):
            step_data = workflow.steps[step_id]
            step_ui_state = ui_state.steps[step_id]
            available_variables = pipeline_state.get_available_variables(step_id)
            sub_components = self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)
            step_objects.append(sub_components)
        # Return the flat per-step list. Appending the list to itself per step would alias it n times.
        return step_objects

    def _render_output_panel(self, pipeline_state: PipelineState):
        """Render one dropdown per required output variable and keep their choices in sync.

        Returns:
            Mapping of required output variable name -> its ``gr.Dropdown``.
        """

        def get_variable_options(available_variables: list[str]) -> list[str]:
            # Outputs may be left unselected or mapped to any non-input variable.
            return [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]

        dropdowns = {}
        variable_options = get_variable_options(pipeline_state.workflow.get_available_variables())
        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str], pipeline_state_dict: td.PipelineStateDict):
            """Refresh the dropdown choices, keeping each field's mapping while it is still valid.

            Resetting values to None here would fire ``dropdown.change`` and wipe
            every output mapping whenever the available variables change.
            """
            options = get_variable_options(available_variables)
            outputs = self.sm.make_pipeline_state(pipeline_state_dict).workflow.outputs
            updates = []
            for output_field in dropdowns:
                current = outputs.get(output_field, UNSELECTED_VAR_NAME)
                # Fall back to "unselected" if the mapped variable no longer exists.
                value = current if current in options else UNSELECTED_VAR_NAME
                updates.append(gr.update(choices=options, value=value))
            return updates

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state, self.pipeline_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def _render_pipeline_preview(self):
        """Render the error display and the editable YAML preview of the workflow.

        When the code box loses focus (blur), its text is handed to
        ``self.sm.update_workflow_from_code``. That call updates ``pipeline_state``,
        ``pipeline_change`` and the error display.
        """
        # Hidden (visible=False): it only serves as a trigger for the export chain below.
        export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)

        self.error_display = gr.HTML(label="Error", elem_id="pipeline-preview-error-display", visible=False)

        # Code box showing the workflow as YAML.
        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )

        self.config_output.blur(
            fn=self.sm.update_workflow_from_code,
            inputs=[self.config_output, self.pipeline_change],
            outputs=[self.pipeline_state, self.pipeline_change, self.error_display],
        )

        # Connect the export button to show the workflow YAML
        self.add_triggers_for_pipeline_export(
            [export_btn.click], self.pipeline_state, scroll=True, expand_accordion=True
        )

    def render(self):
        """Render the pipeline UI: header, dynamic step list, output mapping panel, and preview."""
        # Placeholder for all the step components
        self.all_components = []

        self._render_pipeline_header()

        # The nested renderers must be registered with gr.render. Otherwise they are
        # dead code and neither the steps nor the output panel ever appear.
        # NOTE(review): the triggers are inferred from `pipeline_change` being updated
        # by the add/move/remove/YAML-edit handlers and from `self.app` being stored. Confirm.
        @gr.render(triggers=[self.app.load, self.pipeline_change.change], inputs=[self.pipeline_state])
        def render_steps(pipeline_state_dict: td.PipelineStateDict, evt: gr.EventData):
            """Render all steps in the pipeline, plus the footer "Add Step" button outside simple mode."""
            logger.info(
                f"Rerender pipeline steps triggered! \nInput Pipeline's UI State:{pipeline_state_dict.get('ui_state')}\n Event: {evt.target} {evt._data}"
            )
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            self._render_model_steps(pipeline_state)

            if not self.simple:
                self._render_add_step_button(-1)

        @gr.render(triggers=[self.app.load, self.pipeline_change.change], inputs=[self.pipeline_state])
        def render_output_fields(pipeline_state_dict: td.PipelineStateDict):
            """Render the final output-variable mapping panel."""
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            logger.debug(f"Rerendering output panel: {get_output_panel_state(pipeline_state.workflow)}")
            self._render_output_panel(pipeline_state)

        self._render_pipeline_preview()

    def add_triggers_for_pipeline_export(
        self,
        triggers: list,
        input_pipeline_state: gr.State,
        scroll: bool = False,
        expand_accordion: bool = False,
    ):
        """Wire ``triggers`` to validate ``input_pipeline_state`` and show it as YAML in the preview.

        Must be called after ``_render_pipeline_preview`` has created
        ``error_display``, ``config_output`` and ``config_accordion``.

        Args:
            triggers: Event listeners (e.g. ``button.click``) passed to ``gr.on``.
            input_pipeline_state: State holding the pipeline to validate and export.
            scroll: If True, smoothly scroll the preview into view via client-side JS.
            expand_accordion: If True, open the preview accordion afterwards.
        """
        js = None
        if scroll:
            js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
        # TODO: modify this validate workflow to user input level and not executable level.
        # (workflows that can be converted to UI interface, no logical validation)
        # NOTE(review): validate_workflow_ui reports errors by returning an update rather than
        # raising, so the `.success` step below also runs after a failed validation. Confirm intent.
        event = gr.on(
            triggers,
            self.validate_workflow_ui,
            inputs=[input_pipeline_state],
            outputs=[self.error_display],
        ).success(
            fn=self.sm.get_formatted_config,
            # Export the same state that was just validated.
            inputs=[input_pipeline_state, gr.State("yaml")],
            outputs=[self.config_output, self.error_display],
            js=js,
        )
        if expand_accordion:
            event.then(fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion])