from typing import Literal

from app_configs import AVAILABLE_MODELS, CONFIGS
from components.structs import PipelineState, TossupPipelineState
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
from shared.workflows.structs import TossupWorkflow, Workflow
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator


def _first_missing(required: list[str], available) -> str | None:
    """Return the first name in *required* that is not in *available*, or None if all are present."""
    available_set = set(available)
    return next((name for name in required if name not in available_set), None)


def validate_workflow(
    workflow: TossupWorkflow | Workflow, required_input_vars: list[str], required_output_vars: list[str]
) -> None:
    """
    Validate that a workflow has steps, exposes the required variables, and passes WorkflowValidator.

    Args:
        workflow (TossupWorkflow | Workflow): The workflow to validate.
        required_input_vars (list[str]): Names that must all appear in ``workflow.inputs``.
        required_output_vars (list[str]): Names that must all appear in ``workflow.outputs``.

    Raises:
        ValueError: If the workflow has no steps, or lacks a required input or output variable.
            Only the first missing variable (in the order given) is reported; inputs are
            checked before outputs.
        WorkflowValidationError: Not caught here, so any such error raised by
            ``WorkflowValidator().validate`` propagates to the caller.
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Check that the workflow has the correct structure
    missing_input = _first_missing(required_input_vars, workflow.inputs)
    if missing_input is not None:
        raise ValueError(f"Workflow must have '{missing_input}' as an input")

    missing_output = _first_missing(required_output_vars, workflow.outputs)
    if missing_output is not None:
        raise ValueError(f"Workflow must produce '{missing_output}' as an output")

    # Ensure all steps are properly configured
    WorkflowValidator().validate(workflow)


def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
    """Build a TossupPipelineState from *pipeline_state_dict* and validate its workflow against the tossup config.

    Raises:
        ValueError: Propagated from ``validate_workflow``.
    """
    pipeline_state = TossupPipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["tossup"]["required_input_vars"],
        CONFIGS["tossup"]["required_output_vars"],
    )
    return pipeline_state


def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict) -> PipelineState:
    """Build a PipelineState from *pipeline_state_dict* and validate its workflow against the bonus config.

    Raises:
        ValueError: Propagated from ``validate_workflow``.
    """
    pipeline_state = PipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["bonus"]["required_input_vars"],
        CONFIGS["bonus"]["required_output_vars"],
    )
    return pipeline_state


class UserInputWorkflowValidator:
    """Callable validator for user-edited workflows.

    Every failure is reported as a ``ValueError`` whose message tells the user how to fix the
    workflow.

    Args:
        mode: Which entry of ``CONFIGS`` supplies the required input/output variable names.
    """

    def __init__(self, mode: Literal["tossup", "bonus"]):
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]

    def __call__(self, workflow: TossupWorkflow | Workflow) -> None:
        """Validate *workflow*, raising ValueError with a user-facing explanation on failure.

        Checks run in order and stop at the first failure:
        1. required input variables;
        2. required output variables;
        3. ``WorkflowValidator`` restricted to the names in ``AVAILABLE_MODELS``.
        All ``WorkflowValidationError`` entries are collected into one message.

        Raises:
            ValueError: On any of the failures above; a WorkflowValidationError is chained as the cause.
        """
        missing_input = _first_missing(self.required_input_vars, workflow.inputs)
        if missing_input is not None:
            default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
            raise ValueError(
                f"Missing required input variable: '{missing_input}'. "
                "\nDon't modify the 'inputs' field in the workflow. "
                "Please set it back to:"
                f"\n{default_str}"
            )

        missing_output = _first_missing(self.required_output_vars, workflow.outputs)
        if missing_output is not None:
            default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
            raise ValueError(
                f"Missing required output variable: '{missing_output}'. "
                "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                f"\nMake sure you have values set for all the outputs: {default_str}"
            )

        # Validate the workflow. The validator is (re)built on each call and kept on
        # self.validator as before.
        allowed_model_names = AVAILABLE_MODELS.keys()
        self.validator = WorkflowValidator(allowed_model_names=allowed_model_names)
        try:
            # NOTE(review): allow_empty=True presumably tolerates a workflow with no steps
            # while the user is still editing; confirm against WorkflowValidator.validate.
            self.validator.validate(workflow, allow_empty=True)
        except WorkflowValidationError as e:
            error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
            error_msg_list = [f"- {err.message}" for err in e.errors]
            error_msg = error_msg_total + "\n".join(error_msg_list)
            # Chain explicitly so the original validation error stays available as __cause__.
            raise ValueError(error_msg) from e