File size: 3,967 Bytes
849566b
 
e272e20
f9589f4
 
5f3e7d5
 
f9589f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
849566b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e272e20
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Literal

from app_configs import AVAILABLE_MODELS, CONFIGS
from components.structs import PipelineState, TossupPipelineState
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
from shared.workflows.structs import TossupWorkflow, Workflow
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator


def validate_workflow(
    workflow: TossupWorkflow | Workflow, required_input_vars: list[str], required_output_vars: list[str]
):
    """
    Check a workflow's structure against the required input/output variable names.

    This is task-agnostic: both ``validate_tossup_workflow`` and
    ``validate_bonus_workflow`` call it with their own required names.

    Args:
        workflow (TossupWorkflow | Workflow): The workflow to validate.
        required_input_vars (list[str]): Names that must all appear in ``workflow.inputs``.
        required_output_vars (list[str]): Names that must all appear in ``workflow.outputs``.

    Raises:
        ValueError: If the workflow has no steps, or if a required input or output
            variable is missing (only the first missing name is reported).
        Any exception raised by ``WorkflowValidator().validate`` (presumably
        ``WorkflowValidationError``) propagates unchanged.
    """
    # A workflow with nothing to execute is rejected before any other check.
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Raise on the first required input name that the workflow does not declare.
    present_inputs = set(workflow.inputs)
    for missing in (name for name in required_input_vars if name not in present_inputs):
        raise ValueError(f"Workflow must have '{missing}' as an input")

    # Same check for outputs, performed only after all inputs are present.
    present_outputs = set(workflow.outputs)
    for missing in (name for name in required_output_vars if name not in present_outputs):
        raise ValueError(f"Workflow must produce '{missing}' as an output")

    # Step-level checks are delegated to the shared validator (default configuration).
    step_validator = WorkflowValidator()
    step_validator.validate(workflow)


def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
    """
    Build a ``TossupPipelineState`` from a dict and validate its workflow for the tossup task.

    Args:
        pipeline_state_dict (TossupPipelineStateDict): Keyword arguments for ``TossupPipelineState``.

    Returns:
        TossupPipelineState: The constructed state, once its workflow has passed validation.

    Raises:
        ValueError: Propagated from ``validate_workflow`` when the workflow has no steps
            or lacks a required tossup input/output variable.
    """
    # Construct first so construction errors surface before any config lookup.
    state = TossupPipelineState(**pipeline_state_dict)
    tossup_config = CONFIGS["tossup"]
    validate_workflow(
        state.workflow,
        tossup_config["required_input_vars"],
        tossup_config["required_output_vars"],
    )
    return state


def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict) -> PipelineState:
    """
    Build a ``PipelineState`` from a dict and validate its workflow for the bonus task.

    Args:
        pipeline_state_dict (PipelineStateDict): Keyword arguments for ``PipelineState``.

    Returns:
        PipelineState: The constructed state, once its workflow has passed validation.

    Raises:
        ValueError: Propagated from ``validate_workflow`` when the workflow has no steps
            or lacks a required bonus input/output variable.
    """
    pipeline_state = PipelineState(**pipeline_state_dict)
    validate_workflow(
        pipeline_state.workflow,
        CONFIGS["bonus"]["required_input_vars"],
        CONFIGS["bonus"]["required_output_vars"],
    )
    return pipeline_state


class UserInputWorkflowValidator:
    """
    Callable validator for a user-edited workflow in a given task mode.

    Checks that the workflow still declares the input/output variables required by
    ``CONFIGS[mode]``, then runs ``WorkflowValidator`` restricted to the model names in
    ``AVAILABLE_MODELS``. Every failure is surfaced as a ``ValueError`` whose message
    is meant to be shown to the user.
    """

    def __init__(self, mode: Literal["tossup", "bonus"]):
        """
        Args:
            mode: Task mode whose ``required_input_vars`` and ``required_output_vars``
                are read from ``CONFIGS``.
        """
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]

    def __call__(self, workflow: TossupWorkflow | Workflow):
        """
        Validate ``workflow``. Returns ``None`` when it is valid.

        Args:
            workflow (TossupWorkflow | Workflow): The workflow to check.

        Raises:
            ValueError: If a required input or output variable is missing (the first
                missing one is reported, together with the expected defaults). Also raised
                if ``WorkflowValidator`` reports errors; all of them are listed in the message
                and the original ``WorkflowValidationError`` is chained as ``__cause__``.
        """
        input_vars = set(workflow.inputs)
        for req_var in self.required_input_vars:
            if req_var not in input_vars:
                # Show the full default 'inputs' block so the user can restore it verbatim.
                default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
                raise ValueError(
                    f"Missing required input variable: '{req_var}'. "
                    "\nDon't modify the 'inputs' field in the workflow. "
                    "Please set it back to:"
                    f"\n{default_str}"
                )

        output_vars = set(workflow.outputs)
        for req_var in self.required_output_vars:
            if req_var not in output_vars:
                default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
                raise ValueError(
                    f"Missing required output variable: '{req_var}'. "
                    "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                    f"\nMake sure you have values set for all the outputs: {default_str}"
                )

        # Validate the workflow's steps, restricting models to the configured ones.
        # The validator is built on each call and stored on `self` so it can be
        # inspected after a call.
        allowed_model_names = AVAILABLE_MODELS.keys()
        self.validator = WorkflowValidator(allowed_model_names=allowed_model_names)
        try:
            # NOTE(review): allow_empty=True presumably tolerates empty fields while the
            # user is still editing — confirm against WorkflowValidator.validate.
            self.validator.validate(workflow, allow_empty=True)
        except WorkflowValidationError as e:
            error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
            error_msg_list = [f"- {err.message}" for err in e.errors]
            error_msg = error_msg_total + "\n".join(error_msg_list)
            # Chain explicitly so the structured validation error stays available as __cause__.
            raise ValueError(error_msg) from e