File size: 3,967 Bytes
849566b e272e20 f9589f4 5f3e7d5 f9589f4 849566b e272e20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Literal
from app_configs import AVAILABLE_MODELS, CONFIGS
from components.structs import PipelineState, TossupPipelineState
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
from shared.workflows.structs import TossupWorkflow, Workflow
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator
def validate_workflow(
    workflow: TossupWorkflow | Workflow, required_input_vars: list[str], required_output_vars: list[str]
):
    """
    Validate that a workflow has the required structure and well-formed steps.

    The check is task-agnostic: callers pass the input/output variable names
    their task requires (the tossup and bonus wrappers in this module read them
    from ``CONFIGS``).

    Args:
        workflow (TossupWorkflow | Workflow): The workflow to validate.
        required_input_vars (list[str]): Names that must all appear in ``workflow.inputs``.
        required_output_vars (list[str]): Names that must all appear in ``workflow.outputs``.

    Raises:
        ValueError: If the workflow has no steps, or is missing a required input
            or output variable (only the first missing name is reported).
        WorkflowValidationError: Propagated unchanged from
            ``WorkflowValidator().validate`` when it rejects the workflow; it is
            not converted to ``ValueError`` here.
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Check that the workflow declares every required input and output.
    input_vars = set(workflow.inputs)
    for req_var in required_input_vars:
        if req_var not in input_vars:
            raise ValueError(f"Workflow must have '{req_var}' as an input")

    output_vars = set(workflow.outputs)
    for req_var in required_output_vars:
        if req_var not in output_vars:
            raise ValueError(f"Workflow must produce '{req_var}' as an output")

    # Ensure all steps are properly configured.
    WorkflowValidator().validate(workflow)
def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
    """
    Construct a tossup pipeline state and check its workflow against the tossup config.

    Args:
        pipeline_state_dict (TossupPipelineStateDict): Field values unpacked into
            ``TossupPipelineState``.

    Returns:
        TossupPipelineState: The state built from ``pipeline_state_dict``. It is
        only returned when ``validate_workflow`` raises nothing.
    """
    state = TossupPipelineState(**pipeline_state_dict)
    workflow = state.workflow
    tossup_cfg = CONFIGS["tossup"]
    validate_workflow(
        workflow,
        required_input_vars=tossup_cfg["required_input_vars"],
        required_output_vars=tossup_cfg["required_output_vars"],
    )
    return state
def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict) -> PipelineState:
    """
    Construct a bonus pipeline state and check its workflow against the bonus config.

    This is the bonus counterpart of ``validate_tossup_workflow``. It builds a
    ``PipelineState`` and uses the ``CONFIGS["bonus"]`` required variable names.

    Args:
        pipeline_state_dict (PipelineStateDict): Field values unpacked into ``PipelineState``.

    Returns:
        PipelineState: The state built from ``pipeline_state_dict``. It is only
        returned when ``validate_workflow`` raises nothing.

    Raises:
        ValueError: Propagated from ``validate_workflow`` for structural problems.
        Any error raised by ``PipelineState(...)`` or by ``validate_workflow``'s
        step validation also propagates unchanged.
    """
    pipeline_state = PipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["bonus"]["required_input_vars"],
        CONFIGS["bonus"]["required_output_vars"],
    )
    return pipeline_state
class UserInputWorkflowValidator:
    """Callable validator for user-edited workflows that reports user-facing errors.

    The required input/output variable names are read from ``CONFIGS[mode]`` once,
    at construction time. Each call does the following:

    1. Checks the workflow's inputs and outputs against those names.
    2. Runs ``WorkflowValidator`` with ``allow_empty=True``, restricted to the
       model names in ``AVAILABLE_MODELS``.

    Every failure is reported as a ``ValueError`` whose message explains how to
    fix the workflow.
    """

    def __init__(self, mode: Literal["tossup", "bonus"]):
        """
        Args:
            mode: Which task's entry in ``CONFIGS`` supplies the required
                input/output variable names.
        """
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]

    def __call__(self, workflow: TossupWorkflow | Workflow):
        """Validate ``workflow``; returns None on success.

        Args:
            workflow: The workflow to check. A plain ``Workflow`` is accepted as
                well as a ``TossupWorkflow`` because ``mode`` may be ``"bonus"``.

        Raises:
            ValueError: On the first missing required input, on the first missing
                required output, or when ``WorkflowValidator`` reports errors. In
                the last case all reported errors are listed in the message and
                the original ``WorkflowValidationError`` is chained as the cause.
        """
        input_vars = set(workflow.inputs)
        for req_var in self.required_input_vars:
            if req_var not in input_vars:
                default_str = "inputs:\n" + "\n".join(f"- {var}" for var in self.required_input_vars)
                raise ValueError(
                    f"Missing required input variable: '{req_var}'. "
                    "\nDon't modify the 'inputs' field in the workflow. "
                    "Please set it back to:"
                    f"\n{default_str}"
                )

        output_vars = set(workflow.outputs)
        for req_var in self.required_output_vars:
            if req_var not in output_vars:
                default_str = "[" + ", ".join(f"'{var}'" for var in self.required_output_vars) + "]"
                raise ValueError(
                    f"Missing required output variable: '{req_var}'. "
                    "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                    f"\nMake sure you have values set for all the outputs: {default_str}"
                )

        # A fresh validator is built on every call, as before, and still stored on
        # ``self.validator`` for any existing readers of that attribute.
        # NOTE(review): presumably per-call construction keeps validation state
        # isolated between calls — confirm before hoisting it into __init__.
        self.validator = WorkflowValidator(allowed_model_names=AVAILABLE_MODELS.keys())
        try:
            self.validator.validate(workflow, allow_empty=True)
        except WorkflowValidationError as e:
            error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
            error_msg = error_msg_total + "\n".join(f"- {err.message}" for err in e.errors)
            # Chain explicitly so the structured validation error stays available as __cause__.
            raise ValueError(error_msg) from e
|