Maharshi Gor
Enhances model selection and logging in pipeline components; adds logprobs support and improves UI feedback for disabled sliders.
4f5d1cb
import json | |
from typing import Any | |
import gradio as gr | |
from gradio.components import FormComponent | |
from app_configs import UNSELECTED_VAR_NAME | |
from components.model_pipeline.state_manager import ModelStepUIState, PipelineState, PipelineStateManager | |
from utils import get_full_model_name | |
from workflows.structs import ModelStep | |
from .state_manager import ModelStepStateManager | |
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup | |
def _make_accordion_label(model_step: ModelStep):
    """Build the accordion title for a step: ``"<id>: <name> (<inputs>) → (<outputs>)"``.

    A step whose name is empty or missing is shown as "Untitled"; the input and output
    field names are each joined with ", ".
    """
    title = model_step.name or "Untitled"
    joined_inputs = ", ".join(fld.name for fld in model_step.input_fields)
    joined_outputs = ", ".join(fld.name for fld in model_step.output_fields)
    return f"{model_step.id}: {title} ({joined_inputs}) → ({joined_outputs})"
class ModelStepComponent(FormComponent):
    """
    A custom Gradio component representing a single Step in a pipeline.

    It contains:
    1. Model Provider & System Prompt
    2. Inputs – fields with name, description, and variable used
    3. Outputs – fields with name, type, and description

    Events exposed to callers:
    - on_model_step_change: fires when the step's data state (``model_step_state``) changes
    - on_ui_change: fires when the accordion/tab UI state (``ui_state``) changes
    """

    def __init__(
        self,
        value: ModelStep | gr.State,
        ui_state: ModelStepUIState | gr.State | None = None,
        model_options: list[str] | None = None,
        input_variables: list[str] | None = None,
        max_input_fields=5,
        max_output_fields=5,
        pipeline_state_manager: PipelineStateManager | None = None,
        **kwargs,
    ):
        """Create the step editor.

        Args:
            value: The step to edit, either a ``ModelStep`` or an existing ``gr.State`` holding one.
            ui_state: Expanded/active-tab UI state; a fresh ``ModelStepUIState`` when omitted.
            model_options: Model names offered in the model dropdown (``None`` means none).
            input_variables: Variables an input field may bind to (``None`` means none).
            max_input_fields: Number of input rows pre-rendered; unused rows start hidden.
            max_output_fields: Number of output rows pre-rendered; unused rows start hidden.
            pipeline_state_manager: Used by ``refresh_variable_dropdowns`` to list variables.
            **kwargs: Forwarded to ``FormComponent.__init__``.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }
        # Both option lists are optional; normalise to lists so later concatenation never fails.
        self.model_options = list(model_options or [])
        self.input_variables = [UNSELECTED_VAR_NAME] + list(input_variables or [])
        self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
        self.pipeline_sm: PipelineStateManager | None = pipeline_state_manager

        # Accept an existing gr.State (same treatment as ui_state below) instead of nesting
        # a State inside another State.
        self.model_step_state = value if isinstance(value, gr.State) else gr.State(value)

        ui_state = ui_state or ModelStepUIState()
        if not isinstance(ui_state, gr.State):
            ui_state = gr.State(ui_state)
        self.ui_state: gr.State = ui_state

        self.inputs_count_state = gr.State(len(self.model_step.input_fields))
        self.outputs_count_state = gr.State(len(self.model_step.output_fields))

        # UI components that will be created in render
        self.accordion = None
        self.ui = None
        self.step_name_input = None
        self.model_selection = None
        self.temperature_slider = None
        self.system_prompt = None
        self.inputs_column = None
        self.outputs_column = None
        self.tabs = {}
        self.input_rows = []
        self.output_rows = []

        # gradio's Block.__init__ calls self.render() when render=True (the default), which is
        # why render() is not called explicitly here; the listeners need the rendered children.
        super().__init__(**kwargs)
        self.setup_event_listeners()

    @property
    def model_step(self) -> ModelStep:
        """The step value currently held by ``model_step_state``."""
        return self.model_step_state.value

    @property
    def step_id(self) -> str:
        """Identifier of the wrapped step."""
        return self.model_step.id

    def get_step_config(self) -> dict:
        """Return the wrapped step serialised with ``model_dump()``."""
        return self.model_step.model_dump()

    # UI state accessors
    def is_open(self) -> bool:
        """Whether the step's accordion is expanded."""
        return self.ui_state.value.expanded

    def get_active_tab(self) -> str:
        """Get the current active tab."""
        return self.ui_state.value.active_tab

    def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, InputRowButtonGroup]:
        """Render input row ``i``; rows at or beyond the current input count start hidden.

        Returns:
            ``(row, (name_box, variable_dropdown, description_box), button_group)``.
        """
        inputs = self.model_step.input_fields
        is_visible = i < len(inputs)
        # Only the first row shows the column labels.
        label_visible = i == 0
        # Delete is disabled when this row holds the only input.
        disable_delete = i == 0 and len(inputs) == 1
        initial_name = inputs[i].name if is_visible else ""
        initial_desc = inputs[i].description if is_visible else ""
        initial_var = (inputs[i].variable or UNSELECTED_VAR_NAME) if is_visible else UNSELECTED_VAR_NAME
        with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
            button_group = InputRowButtonGroup(disable_delete=disable_delete)
            inp_var = gr.Dropdown(
                choices=self.input_variables,
                label="Variable Used",
                value=initial_var,
                elem_classes="field-variable",
                scale=1,
                show_label=label_visible,
            )
            inp_name = gr.Textbox(
                label="Input Name",
                placeholder="Field name",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            inp_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
        fields = (inp_name, inp_var, inp_desc)
        return row, fields, button_group

    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, OutputRowButtonGroup]:
        """Render output row ``i``; rows at or beyond the current output count start hidden.

        Returns:
            ``(row, (name_box, type_dropdown, description_box), button_group)``.
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        # Only the first row shows the column labels.
        label_visible = i == 0
        # Delete is disabled when this row holds the only output.
        disable_delete = i == 0 and len(outputs) == 1
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"
        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup(disable_delete=disable_delete)
            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )
            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
        fields = (out_name, out_type, out_desc)
        return row, fields, button_group

    def _render_prompt_tab_content(self):
        """Render the system-prompt textbox of the Model tab."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )

    def _render_inputs_tab_content(self):
        """Render all ``max_fields["input"]`` input rows, recording them in ``input_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            for i in range(self.max_fields["input"]):
                self.input_rows.append(self._render_input_row(i))

    def _render_outputs_tab_content(self):
        """Render all ``max_fields["output"]`` output rows, recording them in ``output_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            for i in range(self.max_fields["output"]):
                self.output_rows.append(self._render_output_row(i))

    def _render_tab_content(self, tab_id: str):
        """Render the body of the tab identified by ``tab_id`` (unknown ids render nothing)."""
        if tab_id == "model-tab":
            self._render_prompt_tab_content()
        elif tab_id == "inputs-tab":
            self._render_inputs_tab_content()
        elif tab_id == "outputs-tab":
            self._render_outputs_tab_content()

    def _render_header(self, model_options: list[str] | tuple[str, ...] | None):
        """Render the header row: step-name box, model dropdown and temperature slider."""
        with gr.Row(elem_classes="step-header-row"):
            self.step_name_input = gr.Textbox(
                label="",
                value=self.model_step.name,
                elem_classes="step-name",
                show_label=False,
                placeholder="Model name...",
            )
            unselected_choice = "Select Model..."
            current_value = (
                get_full_model_name(self.model_step.model, self.model_step.provider)
                if self.model_step.model
                else unselected_choice
            )
            self.model_selection = gr.Dropdown(
                # Unpacking accepts any sequence (list or tuple) as well as None.
                choices=[unselected_choice, *(model_options or ())],
                label="Model Provider",
                show_label=False,
                value=current_value,
                elem_classes="model-dropdown",
                scale=1,
            )
            self.temperature_slider = gr.Slider(
                value=self.model_step.temperature,
                minimum=0.0,
                maximum=5,
                step=0.05,
                info="Temperature",
                show_label=False,
                show_reset_button=False,
            )

    def render(self):
        """Render the component UI inside an accordion and return that accordion.

        The row and tab bookkeeping (``input_rows``, ``output_rows``, ``tabs``) is reset on
        every call, so listeners must be wired to the components of the latest render.
        """
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}
        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")
        with self.accordion:
            self._render_header(self.model_options)
            # Configuration tabs
            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                    self.tabs[tab_id] = tab
        return self.accordion

    def _setup_event_listeners_for_view_change(self):
        """Mirror tab selection and accordion expand/collapse into ``ui_state``."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )
        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )

    def _setup_event_listeners_model_tab(self):
        """Wire the header widgets and the system prompt to the step state."""
        # Renaming also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )
        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )
        self.model_selection.input(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )
        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )

    def _setup_event_listeners_inputs_tab(self):
        """Wire every input row's fields and add/delete buttons to the step state."""
        # Add/delete rewrite every row and field, so collect them once rather than per row.
        rows = [row for (row, _, _) in self.input_rows]
        input_fields = [field for (_, fields, _) in self.input_rows for field in fields]
        for i, (_, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)
            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state, *rows, *input_fields],
            )
            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state, *rows, *input_fields],
            )

    def _setup_event_listeners_outputs_tab(self):
        """Wire every output row's fields and add/delete/move buttons to the step state."""
        # Add/delete/move rewrite every row and field, so collect them once rather than per row.
        rows = [row for (row, _, _) in self.output_rows]
        output_fields = [field for (_, fields, _) in self.output_rows for field in fields]
        for i, (_, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields
            row_index = gr.State(i)
            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )
            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )
            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state, *rows, *output_fields],
            )
            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state, *rows, *output_fields],
            )
            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=[self.model_step_state, *output_fields],
            )
            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=[self.model_step_state, *output_fields],
            )

    def setup_event_listeners(self):
        """Set up all event listeners for this component.

        Must run after ``render()`` has created the child components (``__init__`` calls it
        right after ``super().__init__()``).
        """
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

    def on_model_step_change(self, fn, inputs, outputs):
        """Set up an event listener for the model change event."""
        return self.model_step_state.change(fn, inputs, outputs)

    def on_ui_change(self, fn, inputs, outputs):
        """Set up an event listener for the UI change event."""
        return self.ui_state.change(fn, inputs, outputs)

    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step; return the new step and an accordion label update."""
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)

    def refresh_variable_dropdowns(self, pipeline_state: PipelineState) -> list:
        """Build updates that refresh the variable choices of every input-row dropdown.

        Gradio 4+ components have no instance ``update()`` method; UI changes have to be
        returned from an event handler. This therefore returns one ``gr.update(choices=...)``
        per input row, in row order, for use as the outputs of those row dropdowns. The
        "unselected" placeholder is kept as a choice, matching the choices built in ``__init__``.
        """
        variable_choices: list[str] = []
        if self.pipeline_sm is not None:
            variable_choices = list(self.pipeline_sm.get_all_variables(pipeline_state))
        if UNSELECTED_VAR_NAME not in variable_choices:
            variable_choices = [UNSELECTED_VAR_NAME] + variable_choices
        return [gr.update(choices=variable_choices) for _ in self.input_rows]

    def _update_model_and_refresh_ui(self, updated_model_step):
        """Store ``updated_model_step`` as the state's value and sync the accordion label.

        The label is assigned on the accordion component itself (Gradio 4+ has no instance
        ``update()``); presumably this only affects the config served on later page loads.
        Live clients need a returned ``gr.update``, as in ``_update_state_and_label``.
        """
        self.model_step_state.value = updated_model_step
        if self.accordion is not None:
            self.accordion.label = _make_accordion_label(updated_model_step)
        return updated_model_step