Maharshi Gor
Enhances model selection and logging in pipeline components; adds logprobs support and improves UI feedback for disabled sliders.
4f5d1cb
import gradio as gr | |
import yaml | |
from loguru import logger | |
from app_configs import UNSELECTED_VAR_NAME | |
from components import commons | |
from components.model_pipeline.state_manager import ( | |
ModelStepUIState, | |
PipelineState, | |
PipelineStateManager, | |
PipelineUIState, | |
) | |
from components.model_step.model_step import ModelStepComponent | |
from components.utils import make_state | |
from workflows.structs import ModelStep, Workflow | |
from workflows.validators import WorkflowValidator | |
def validate_simple_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
    """Validate a single-step ("simple") workflow.

    Only the first step of ``workflow.steps`` is inspected: it must declare output
    fields, and those must include every name in ``required_output_variables``.

    Args:
        workflow: The workflow to validate.
        required_output_variables: Output variable names the step must produce.

    Returns:
        The same ``workflow`` object, unchanged.

    Raises:
        ValueError: If the workflow has no steps, the first step has no output
            fields, or some required output variables are missing.
    """
    step = next(iter(workflow.steps.values()), None)
    if step is None:
        raise ValueError("No steps found in the workflow")
    if not step.output_fields:
        raise ValueError("No output fields found in the workflow")
    output_field_names = {output.name for output in step.output_fields}
    missing_vars = set(required_output_variables) - output_field_names
    if missing_vars:
        # Sorted so the error message is deterministic.
        raise ValueError(f"Missing required output variables: {sorted(missing_vars)}")
    return workflow
def validate_complex_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
    """Validate a multi-step workflow.

    Currently a pass-through: no checks are performed and ``workflow`` is returned
    unchanged. ``required_output_variables`` is accepted so the signature matches
    ``validate_simple_workflow`` and both can be used interchangeably.

    Args:
        workflow: The workflow to validate.
        required_output_variables: Required output variable names (currently unused).

    Returns:
        The same ``workflow`` object.
    """
    # NOTE(review): the first-step output-field check used for simple workflows does
    # not fit multi-step pipelines, whose final outputs can be mapped to variables
    # produced by any step. Add real validation here if it is needed.
    return workflow
def parse_yaml_workflow(yaml_str: str) -> Workflow:
    """Parse a YAML document into a ``Workflow``.

    ``yaml.safe_load`` is used, so only plain YAML types are constructed.

    Args:
        yaml_str: YAML text whose top level is a mapping of ``Workflow`` fields.

    Returns:
        The constructed ``Workflow``.

    Raises:
        ValueError: If the document is empty or its top level is not a mapping.
        yaml.YAMLError: If the text is not valid YAML.
        Exceptions raised by ``Workflow(...)`` for invalid fields propagate unchanged.
    """
    workflow = yaml.safe_load(yaml_str)
    # An empty document loads as None, and a scalar or list cannot be splatted into
    # Workflow(**...). Reject both with a clear message.
    if not isinstance(workflow, dict):
        raise ValueError(f"Workflow YAML must be a mapping at the top level, got {type(workflow).__name__}")
    return Workflow(**workflow)
def update_workflow_from_code(yaml_str: str, ui_state: PipelineUIState) -> PipelineState:
    """Build a new ``PipelineState`` from workflow YAML edited in the UI.

    Args:
        yaml_str: YAML text of the workflow.
        ui_state: Current UI state. Not used: the UI state is rebuilt from the parsed
            workflow via ``PipelineUIState.from_workflow``. The parameter is kept
            because the editor's blur event passes it as an input.

    Returns:
        A ``PipelineState`` holding the parsed workflow and its rebuilt UI state.

    Raises:
        gr.Error: If the YAML cannot be parsed or does not describe a valid workflow,
            so the problem is shown to the user in the UI.
    """
    try:
        workflow = parse_yaml_workflow(yaml_str)
    except (yaml.YAMLError, ValueError, TypeError) as e:
        raise gr.Error(f"Invalid workflow configuration: {e}") from e
    new_ui_state = PipelineUIState.from_workflow(workflow)
    return PipelineState(workflow=workflow, ui_state=new_ui_state)
class PipelineInterface:
    """Gradio UI for building, editing and exporting a model pipeline (workflow).

    The interface keeps several states created with ``make_state``:

    * ``ui_state`` holds the step order and per-step UI state.
    * ``pipeline_state`` holds the workflow plus its UI state.
    * ``variables_state`` holds the variables available in the workflow.
    * ``model_selection_state`` holds the model selection.

    State transitions are delegated to a ``PipelineStateManager`` (``self.sm``).
    The UI is rendered immediately on construction.

    Args:
        workflow: Initial workflow to display.
        ui_state: Initial UI state; derived from ``workflow`` when not given.
        model_options: Model choices forwarded to every ``ModelStepComponent``.
        simple: When True, steps are rendered without move/remove controls, no
            "Add Step" button is shown, and export validation uses
            ``validate_simple_workflow``.
        show_pipeline_selector: Stored as an attribute (not read elsewhere in this class).
    """

    def __init__(
        self,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        simple: bool = False,
        show_pipeline_selector: bool = False,
    ):
        self.model_options = model_options
        self.simple = simple
        self.show_pipeline_selector = show_pipeline_selector
        if not ui_state:
            ui_state = PipelineUIState.from_workflow(workflow)
        self.ui_state = make_state(ui_state)
        self.pipeline_state = make_state(PipelineState(workflow=workflow, ui_state=ui_state))
        self.variables_state = make_state(workflow.get_available_variables())
        self.model_selection_state = make_state({})
        self.sm = PipelineStateManager()
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        # Render the pipeline UI
        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one pipeline step and wire its events to the pipeline state.

        Args:
            model_step: The step to render.
            step_ui_state: UI state of this step.
            available_variables: Variables offered to the step as inputs.
            position: Index of the step in the pipeline, used by the step controls.
            n_steps: Total number of steps. The move/remove controls are visible and
                interactive only when there is more than one step.

        Returns:
            The ``ModelStepComponent`` in simple mode; otherwise a tuple
            ``(step_component, up_button, down_button, remove_button)``.
        """
        with gr.Column(elem_classes="step-container"):
            # Create the step component
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
            )

            # Edits to the step's content update the pipeline, UI, variables and model selection states.
            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state, self.model_selection_state],
            )

            # UI-only changes of the step update just the pipeline and UI states.
            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state, self.ui_state],
            )

            if self.simple:
                return step_interface

            is_multi_step = n_steps > 1
            # Add step controls below the step; hidden and disabled for a single-step pipeline.
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)
            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire the move-up, move-down and remove buttons of the step at ``position``.

        Moving only updates ``ui_state``. Removing also updates ``pipeline_state``
        and ``variables_state``.
        """
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        up_button.click(self.sm.move_up, inputs=[self.ui_state, position_state], outputs=self.ui_state)
        down_button.click(self.sm.move_down, inputs=[self.ui_state, position_state], outputs=self.ui_state)
        remove_button.click(
            self.sm.remove_step,
            inputs=[self.pipeline_state, position_state],
            outputs=[self.pipeline_state, self.ui_state, self.variables_state],
        )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button in the pipeline header (0) or footer (-1).

        Args:
            position: ``0`` or ``-1``; forwarded to ``PipelineStateManager.add_step``.

        Returns:
            The created ``gr.Button``.

        Raises:
            ValueError: If ``position`` is not 0 or -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, gr.State(position)],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state],
            )
        return add_step_btn

    def _render_output_panel(self, available_variables: list[str], pipeline_state: PipelineState):
        """Render dropdowns mapping each required output variable to a pipeline variable.

        The choices are ``UNSELECTED_VAR_NAME`` followed by the available variables
        that are not workflow inputs. When ``variables_state`` changes, the choices
        are refreshed with the same rule. Each dropdown keeps its current mapping
        (read from ``self.pipeline_state``) if that mapping is still a valid choice;
        otherwise it falls back to ``UNSELECTED_VAR_NAME``.

        Args:
            available_variables: Variables currently available in the pipeline.
            pipeline_state: State used for the dropdowns' initial values.

        Returns:
            Dict mapping each required output variable name to its ``gr.Dropdown``.
        """

        def _variable_options(variables: list[str]) -> list[str]:
            # Workflow input variables are excluded from the output choices.
            return [UNSELECTED_VAR_NAME] + [v for v in variables if v not in self.input_variables]

        dropdowns = {}
        variable_options = _variable_options(available_variables)
        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str], state: PipelineState):
            """Refresh the dropdown choices while preserving still-valid selections.

            A programmatic value update also fires the dropdown's ``.change`` listener,
            which writes the value into the workflow outputs. Blindly resetting the
            values would therefore erase the output mapping.
            """
            options = _variable_options(available_variables)
            updates = []
            for output_field in dropdowns:
                current = state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                value = current if current in options else UNSELECTED_VAR_NAME
                updates.append(gr.update(choices=options, value=value))
            return updates

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state, self.pipeline_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def validate_workflow(self, state: PipelineState) -> PipelineState:
        """Validate ``state.workflow`` and return the state with the validated workflow.

        Uses ``validate_simple_workflow`` in simple mode and
        ``validate_complex_workflow`` otherwise.

        Raises:
            gr.Error: If validation raises ``ValueError``; its message is shown in the UI.
        """
        validator = validate_simple_workflow if self.simple else validate_complex_workflow
        try:
            state.workflow = validator(state.workflow, self.required_output_variables)
        except ValueError as e:
            raise gr.Error(str(e)) from e
        return state

    def _render_pipeline_header(self):
        """Render the instructions listing the pipeline's input and output variables."""
        input_variables_str = ", ".join(f"`{variable}`" for variable in self.input_variables)
        output_variables_str = ", ".join(f"`{variable}`" for variable in self.required_output_variables)
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline that takes in the following input variables and outputs the following output variables:"
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def render(self):
        """Render the pipeline UI: header, export button and editable YAML preview.

        The YAML editor (``self.config_output``) lives in a collapsed accordion
        (``self.config_accordion``). When the editor loses focus, ``pipeline_state``
        is rebuilt from its YAML. Clicking "Export Pipeline" validates the workflow,
        writes the formatted YAML config into the editor and opens the accordion.
        """
        # Create a placeholder for all the step components
        self.all_components = []

        self._render_pipeline_header()

        # NOTE(review): ``render_steps`` and ``render_output_fields`` are neither called
        # nor registered (e.g. via a ``@gr.render`` decorator) in this method.
        # Confirm how they are meant to be wired.
        def render_steps(state: PipelineState, ui_state: PipelineUIState):
            """Render all steps in ``ui_state.step_ids`` order, then the footer "Add Step" button."""
            logger.info(f"Rerender triggered! Current UI State:{ui_state.model_dump()}")
            workflow = state.workflow
            for i, step_id in enumerate(ui_state.step_ids):
                self._render_step(
                    workflow.steps[step_id],
                    ui_state.steps[step_id],
                    self.sm.get_all_variables(state, step_id),
                    i,
                    ui_state.n_steps,
                )
            if not self.simple:
                self._render_add_step_button(-1)

        def render_output_fields(available_variables, pipeline_state):
            """Render the final output-variable mapping panel."""
            logger.info(f"Rerendering output panel: {available_variables} {pipeline_state.workflow}")
            self._render_output_panel(available_variables, pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")

        # Editable YAML view of the workflow, collapsed until an export succeeds.
        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                label="Workflow Configuration",
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )

        # Rebuild the pipeline state from the edited YAML when the editor loses focus.
        self.config_output.blur(
            fn=update_workflow_from_code,
            inputs=[self.config_output, self.ui_state],
            outputs=[self.pipeline_state],
        )

        # Export: validate, then write the formatted YAML config into the editor...
        self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state)
        # ...and open the preview accordion once validation succeeds.
        export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[self.pipeline_state]).success(
            fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
        )

    def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State):
        """Validate the pipeline and show its YAML config whenever any of ``triggers`` fires.

        Args:
            triggers: Gradio event triggers (e.g. ``[button.click]``) passed to ``gr.on``.
            input_pipeline_state: State holding the ``PipelineState`` to validate. The
                validated result is written to ``self.pipeline_state``, which is then
                formatted as YAML into ``self.config_output``.
        """
        gr.on(
            triggers,
            self.validate_workflow,
            inputs=[input_pipeline_state],
            outputs=[self.pipeline_state],
        ).success(
            fn=self.sm.get_formatted_config,
            inputs=[self.pipeline_state, gr.State("yaml")],
            outputs=[self.config_output],
            # Client-side: smoothly scroll the preview accordion into view.
            js="() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}",
        )