|
from typing import Literal |
|
|
|
from app_configs import AVAILABLE_MODELS, CONFIGS |
|
from components.structs import PipelineState, TossupPipelineState |
|
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict |
|
from workflows.structs import TossupWorkflow, Workflow |
|
from workflows.validators import WorkflowValidationError, WorkflowValidator |
|
|
|
|
|
def validate_workflow(
    workflow: TossupWorkflow | Workflow, required_input_vars: list[str], required_output_vars: list[str]
):
    """
    Check that a workflow has steps and declares every required variable.

    The checks run in order: the workflow must have at least one step, every
    name in ``required_input_vars`` must appear in ``workflow.inputs``, and
    every name in ``required_output_vars`` must appear in ``workflow.outputs``.
    Only the first missing name (in the order given) is reported. Finally the
    workflow is handed to a default-constructed ``WorkflowValidator``.

    Args:
        workflow (TossupWorkflow | Workflow): The workflow to validate
        required_input_vars (list[str]): Names that must be among the workflow inputs
        required_output_vars (list[str]): Names that must be among the workflow outputs

    Raises:
        ValueError: If the workflow has no steps or a required input/output is missing.
            Whatever ``WorkflowValidator.validate`` raises propagates unchanged
            (presumably ``WorkflowValidationError`` -- confirm in workflows.validators).
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Membership is tested against a set built once per side.
    declared_inputs = set(workflow.inputs)
    absent_inputs = [name for name in required_input_vars if name not in declared_inputs]
    if absent_inputs:
        raise ValueError(f"Workflow must have '{absent_inputs[0]}' as an input")

    declared_outputs = set(workflow.outputs)
    absent_outputs = [name for name in required_output_vars if name not in declared_outputs]
    if absent_outputs:
        raise ValueError(f"Workflow must produce '{absent_outputs[0]}' as an output")

    # Structural validation is delegated to the shared validator.
    structural_validator = WorkflowValidator()
    structural_validator.validate(workflow)
|
|
|
|
|
def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
    """
    Build a ``TossupPipelineState`` from its dict form and validate its workflow.

    The dict is expanded as keyword arguments into ``TossupPipelineState``; the
    resulting state's workflow is then checked by ``validate_workflow`` against
    the required input/output variables listed under ``CONFIGS["tossup"]``.

    Args:
        pipeline_state_dict (TossupPipelineStateDict): Keyword arguments for ``TossupPipelineState``

    Returns:
        TossupPipelineState: The constructed state, returned only if validation passes.

    Raises:
        ValueError: Propagated from ``validate_workflow`` when the workflow is misconfigured.
    """
    state = TossupPipelineState(**pipeline_state_dict)
    workflow = state.workflow
    requirements = CONFIGS["tossup"]
    validate_workflow(workflow, requirements["required_input_vars"], requirements["required_output_vars"])
    return state
|
|
|
|
|
def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict) -> PipelineState:
    """
    Build a ``PipelineState`` from its dict form and validate its workflow.

    The dict is expanded as keyword arguments into ``PipelineState``; the
    resulting state's workflow is then checked by ``validate_workflow`` against
    the required input/output variables listed under ``CONFIGS["bonus"]``.

    Args:
        pipeline_state_dict (PipelineStateDict): Keyword arguments for ``PipelineState``

    Returns:
        PipelineState: The constructed state, returned only if validation passes.

    Raises:
        ValueError: Propagated from ``validate_workflow`` when the workflow is misconfigured.
    """
    pipeline_state = PipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["bonus"]["required_input_vars"],
        CONFIGS["bonus"]["required_output_vars"],
    )
    return pipeline_state
|
|
|
|
|
class UserInputWorkflowValidator:
    """
    Callable that validates a user-edited workflow and reports problems as ``ValueError``.

    The required input and output variable names are read from ``CONFIGS[mode]``
    when the instance is created. Calling the instance checks a workflow against
    those names and then runs ``WorkflowValidator`` with the model names taken
    from ``AVAILABLE_MODELS``.
    """

    def __init__(self, mode: Literal["tossup", "bonus"]):
        """
        Load the required variable names for the given mode.

        Args:
            mode: Key into ``CONFIGS``; either "tossup" or "bonus".
        """
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]

    def __call__(self, workflow: TossupWorkflow | Workflow):
        """
        Validate ``workflow`` and raise a user-readable error on the first problem found.

        Args:
            workflow (TossupWorkflow | Workflow): The workflow to validate

        Raises:
            ValueError: If a required input or output variable is missing (only the
                first missing name is reported, with a hint on how to restore it), or
                if ``WorkflowValidator`` reports errors. In the latter case the message
                lists every reported error and the original ``WorkflowValidationError``
                is attached as ``__cause__``.
        """
        input_vars = set(workflow.inputs)
        for req_var in self.required_input_vars:
            if req_var not in input_vars:
                # Show the full expected 'inputs' section so the user can paste it back.
                default_str = "inputs:\n" + "\n".join(f"- {var}" for var in self.required_input_vars)
                raise ValueError(
                    f"Missing required input variable: '{req_var}'. "
                    "\nDon't modify the 'inputs' field in the workflow. "
                    "Please set it back to:"
                    f"\n{default_str}"
                )

        output_vars = set(workflow.outputs)
        for req_var in self.required_output_vars:
            if req_var not in output_vars:
                # List every expected output key, not just the missing one.
                default_str = "[" + ", ".join(f"'{var}'" for var in self.required_output_vars) + "]"
                raise ValueError(
                    f"Missing required output variable: '{req_var}'. "
                    "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                    f"\nMake sure you have values set for all the outputs: {default_str}"
                )

        allowed_model_names = AVAILABLE_MODELS.keys()
        # Stored on the instance as in the original; other code may read ``self.validator``.
        self.validator = WorkflowValidator(allowed_model_names=allowed_model_names)
        try:
            self.validator.validate(workflow, allow_empty=True)
        except WorkflowValidationError as e:
            error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
            error_msg_list = [f"- {err.message}" for err in e.errors]
            error_msg = error_msg_total + "\n".join(error_msg_list)
            # Chain explicitly so the underlying validation error stays attached as the cause.
            raise ValueError(error_msg) from e
|
|