Maharshi Gor
Add user input validation to pipeline interfaces and display errors on pipeline change.
849566b
from typing import Any, Literal | |
from pydantic import BaseModel, Field, model_validator | |
from workflows.structs import ModelStep, TossupWorkflow, Workflow | |
def make_step_id(step_number: int) -> str:
    """Make a step id from a 0-based step number.

    Ids use spreadsheet-column lettering: 0 -> "A", 25 -> "Z", 26 -> "AA",
    27 -> "AB", ..., 701 -> "ZZ", 702 -> "AAA". This is identical to the
    previous one/two-letter scheme for 0..701 and keeps producing letters
    beyond "ZZ" instead of non-letter characters.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"Step number must be non-negative, got {step_number}")
    letters = []
    # Shift to 1-based so the loop is plain bijective base-26 (no zero digit).
    remaining = step_number + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + offset))
    return "".join(reversed(letters))
def make_step_number(step_id: str) -> int:
    """Make a 0-based step number from a step id; the inverse of ``make_step_id``.

    "A" -> 0, "Z" -> 25, "AA" -> 26, "AB" -> 27, "BA" -> 52, "ZZ" -> 701,
    "AAA" -> 702. Characters outside A-Z are not rejected; they just yield
    numbers that no ``make_step_id`` output maps to.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("Step id must be a non-empty string")
    total = 0
    for char in step_id:
        # Each letter is a bijective base-26 digit in 1..26.
        total = total * 26 + (ord(char) - ord("A") + 1)
    return total - 1
class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component.

    Instances are frozen; use ``update`` to obtain a modified copy.
    """

    # pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    model_config = {"frozen": True}

    expanded: bool = True
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with field ``key`` set to ``value``.

        The result is re-validated (``model_copy(update=...)`` would skip
        validation), so an unknown field name or a value outside the allowed
        options raises instead of silently producing an inconsistent state.

        Raises:
            ValueError: If ``key`` is not a field of this model, or (as a
                pydantic ``ValidationError``, a ``ValueError`` subclass) if
                ``value`` is invalid for that field.
        """
        if key not in type(self).model_fields:
            raise ValueError(f"Unknown UI state field: {key}")
        return type(self).model_validate(self.model_dump() | {key: value})
class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    ``step_ids`` holds the display order of the steps and ``steps`` maps each
    step id to its ``ModelStepUIState``. Update methods return a new instance
    and never mutate the receiver's containers.
    """

    step_ids: list[str] = Field(default_factory=list)
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        # Give every listed step a default UI state when none was supplied, so a
        # step present in ``step_ids`` always has an entry in ``steps``.
        missing = {step_id: ModelStepUIState() for step_id in self.step_ids if step_id not in self.steps}
        if missing:
            self.steps = self.steps | missing
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in the pipeline, or ``None`` if absent."""
        try:
            return self.step_ids.index(step_id)
        except ValueError:
            return None

    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Create a pipeline UI state with a default step state for each workflow step."""
        step_ids = list(workflow.steps.keys())
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState") -> "PipelineUIState":
        """Create a pipeline UI state from the workflow of ``pipeline_state``."""
        return cls.from_workflow(pipeline_state.workflow)

    # Update methods: each builds new containers and returns a copy.
    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Return a copy with ``step_id`` inserted at ``position`` (``-1`` appends).

        Raises:
            ValueError: If ``step_id`` is already in the pipeline.
        """
        if step_id in self.step_ids:
            raise ValueError(f"Step {step_id} already exists in pipeline. Step IDs: {self.step_ids}")
        if position == -1:
            position = len(self.step_ids)
        # Copy before inserting so the current instance is left untouched.
        step_ids = list(self.step_ids)
        step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Return a copy without ``step_id``.

        Raises:
            ValueError: If ``step_id`` is not in the pipeline.
        """
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        step_ids = list(self.step_ids)
        step_ids.remove(step_id)
        steps = {key: state for key, state in self.steps.items() if key != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Return a copy with the UI state of ``step_id`` replaced by ``ui_state``.

        Raises:
            ValueError: If ``step_id`` has no UI state in this pipeline.
        """
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Pairs a ``Workflow`` with the ``PipelineUIState`` that tracks how its steps
    are displayed. Update methods return a new instance built with
    ``model_copy`` and leave ``self`` unchanged. The exception is
    ``update_output_variable``, which edits ``self.workflow.outputs`` in place.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineState":
        """Create a pipeline state from a workflow, with default UI state for each step."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy with ``workflow`` replaced; the UI state is kept as-is."""
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Return a new state with ``step`` added at ``position`` in the UI order.

        Args:
            position: Index in ``0..n_steps``, or ``-1`` to append.
            step: The step to add; its ``id`` must not already be in the workflow.

        Raises:
            ValueError: If the step id already exists or ``position`` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")
        n_steps = self.n_steps
        if position != -1 and not 0 <= position <= n_steps:
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {n_steps} or -1")
        workflow = self.workflow.add_step(step)
        # Keep the new UI state local so the current instance is not modified.
        ui_state = self.ui_state.insert_step(step.id, position)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Return a new state without the step shown at ``position``.

        ``position`` indexes ``ui_state.step_ids`` with normal list semantics:
        negative values count from the end, out-of-range values raise ``IndexError``.
        """
        step_id = self.ui_state.step_ids[position]
        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Return a new state with ``step`` replaced (and its UI state, if given).

        Raises:
            ValueError: If ``step.id`` is not in the workflow.
        """
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")
        workflow = self.workflow.update_step(step)
        update = {"workflow": workflow}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

    def update_output_variable(self, target: str, produced_variable: str) -> "PipelineState":
        """Set workflow output ``target`` to ``produced_variable``.

        NOTE: unlike the other update methods, this mutates ``self.workflow.outputs``
        in place and returns ``self``; kept that way because callers may rely on it.
        """
        self.workflow.outputs[target] = produced_variable
        return self

    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Return the workflow's available variables.

        When ``model_step_id`` is given, variables whose name starts with
        ``"<model_step_id>."`` are excluded (presumably that step's own outputs;
        confirm against ``Workflow.get_available_variables``).
        """
        variables = self.available_variables
        if model_step_id is None:
            return variables
        own_prefix = f"{model_step_id}."
        return [var for var in variables if not var.startswith(own_prefix)]

    @property
    def available_variables(self) -> list[str]:
        """All variables reported by ``workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Return an id for a new step that is not already used by the workflow.

        Starts right after the highest existing step number ("A" for an empty
        workflow) and skips any id that is already taken, so the result can be
        passed to ``insert_step`` without a duplicate-id error.
        """
        next_number = max((make_step_number(step_id) + 1 for step_id in self.workflow.steps.keys()), default=0)
        # Ids outside A-Z can map to negative numbers; never start below "A".
        next_number = max(next_number, 0)
        while (step_id := make_step_id(next_number)) in self.workflow.steps:
            next_number += 1
        return step_id
class TossupPipelineState(PipelineState):
    """Pipeline state whose workflow is a ``TossupWorkflow``.

    After a step update or an output-variable update, the resulting state's
    workflow is replaced by the return value of its ``refresh_buzzer()`` method.
    """

    workflow: TossupWorkflow

    @staticmethod
    def _with_refreshed_buzzer(state: "TossupPipelineState") -> "TossupPipelineState":
        """Assign ``state.workflow.refresh_buzzer()`` back onto ``state`` and return it."""
        state.workflow = state.workflow.refresh_buzzer()
        return state

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "TossupPipelineState":
        """Update a step via the base class, then refresh the buzzer on the result."""
        return self._with_refreshed_buzzer(super().update_step(step, ui_state))

    def update_output_variable(self, target: str, produced_variable: str) -> "TossupPipelineState":
        """Update an output variable via the base class, then refresh the buzzer on the result."""
        return self._with_refreshed_buzzer(super().update_output_variable(target, produced_variable))