Maharshi Gor
Add user input validation to pipeline interfaces and error display on pipeline change.
849566b
from typing import Literal | |
from app_configs import CONFIGS | |
from components.structs import PipelineState, TossupPipelineState | |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict | |
from workflows.structs import TossupWorkflow, Workflow | |
from workflows.validators import WorkflowValidator | |
def validate_workflow(
    workflow: TossupWorkflow | Workflow,
    required_input_vars: list[str],
    required_output_vars: list[str],
) -> None:
    """
    Validate that a workflow declares the variables a task needs and is well-formed.

    This helper is task-agnostic: callers supply the required variable names
    (e.g. from ``CONFIGS["tossup"]`` or ``CONFIGS["bonus"]``).

    Args:
        workflow (TossupWorkflow | Workflow): The workflow to validate.
        required_input_vars (list[str]): Names that must all appear in ``workflow.inputs``.
        required_output_vars (list[str]): Names that must all appear in ``workflow.outputs``.

    Raises:
        ValueError: If the workflow has no steps, or is missing a required input
            or output variable. Only the first missing variable (in the order given)
            is reported. Any exception raised by ``WorkflowValidator().validate``
            also propagates unchanged.
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Check the workflow's interface: every required input must be declared.
    input_vars = set(workflow.inputs)
    for req_var in required_input_vars:
        if req_var not in input_vars:
            raise ValueError(f"Workflow must have '{req_var}' as an input")

    # ...and every required output must be produced.
    output_vars = set(workflow.outputs)
    for req_var in required_output_vars:
        if req_var not in output_vars:
            raise ValueError(f"Workflow must produce '{req_var}' as an output")

    # Per-step structural checks are delegated to the shared WorkflowValidator.
    WorkflowValidator().validate(workflow)
def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
    """
    Build a TossupPipelineState from a plain dict and check its workflow for the tossup task.

    Args:
        pipeline_state_dict (TossupPipelineStateDict): Keyword arguments used to
            construct the ``TossupPipelineState``.

    Returns:
        TossupPipelineState: The constructed state, once its workflow has passed
            ``validate_workflow`` against the ``CONFIGS["tossup"]`` requirements.
    """
    state = TossupPipelineState(**pipeline_state_dict)
    # Look up the task config only after the state is built, so construction
    # errors surface first, exactly as before.
    tossup_config = CONFIGS["tossup"]
    validate_workflow(
        state.workflow,
        required_input_vars=tossup_config["required_input_vars"],
        required_output_vars=tossup_config["required_output_vars"],
    )
    return state
def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict) -> PipelineState:
    """
    Build a PipelineState from a plain dict and validate its workflow for the bonus task.

    Args:
        pipeline_state_dict (PipelineStateDict): Keyword arguments used to
            construct the ``PipelineState``.

    Returns:
        PipelineState: The constructed pipeline state, after validation passes.

    Raises:
        ValueError: If ``validate_workflow`` rejects the workflow against the
            ``CONFIGS["bonus"]`` required input/output variables. Errors raised
            while constructing the ``PipelineState`` propagate unchanged.
    """
    pipeline_state = PipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["bonus"]["required_input_vars"],
        CONFIGS["bonus"]["required_output_vars"],
    )
    return pipeline_state
class UserInputWorkflowValidator:
    """
    Callable that checks a user-edited workflow still declares the required variables.

    The required input/output variable names are read once from ``CONFIGS[mode]``.
    When a required variable is missing, calling the instance raises a
    ``ValueError`` whose message tells the user how to restore the expected
    ``inputs`` / ``outputs`` fields.
    """

    def __init__(self, mode: Literal["tossup", "bonus"]):
        """
        Args:
            mode: Which task configuration in ``CONFIGS`` to validate against.
        """
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]

    def __call__(self, workflow: TossupWorkflow | Workflow):
        """
        Check that ``workflow`` declares every required input and output variable.

        Args:
            workflow: The workflow to check. Annotated to accept either workflow
                type (matching ``validate_workflow``), since ``mode`` may be "bonus".

        Raises:
            ValueError: On the first missing required input variable; if all inputs
                are present, on the first missing required output variable.
        """
        input_vars = set(workflow.inputs)
        for req_var in self.required_input_vars:
            if req_var not in input_vars:
                # Show the full expected 'inputs' block so the user can paste it back.
                default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
                raise ValueError(
                    f"Missing required input variable: '{req_var}'. "
                    "\nDon't modify the 'inputs' field in the workflow. "
                    "Please set it back to:"
                    f"\n{default_str}"
                )

        output_vars = set(workflow.outputs)
        for req_var in self.required_output_vars:
            if req_var not in output_vars:
                # List every expected output key so the user sees the full set at once.
                default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
                raise ValueError(
                    f"Missing required output variable: '{req_var}'. "
                    "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                    f"\nMake sure you have values set for all the outputs: {default_str}"
                )