# %%
import json
from abc import ABC, abstractmethod
from typing import Literal

import gradio as gr
import yaml
from loguru import logger
from pydantic import BaseModel, ValidationError

from app_configs import UNSELECTED_VAR_NAME
from components import typed_dicts as td
from components import utils
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
from envs import DOCS_REPO_URL, DOCS_URL
from shared.workflows.factory import create_new_llm_step
from shared.workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow


def get_output_panel_state(workflow: Workflow) -> dict:
    """Collect the workflow data shown in the output panel.

    Args:
        workflow: The workflow to summarize.

    Returns:
        A dict with the keys ``"variables"``, ``"models"`` and
        ``"output_models"``; a ``"buzzer"`` entry (the full buzzer dump,
        defaults included) is added when ``workflow`` is a ``TossupWorkflow``.
    """
    state = {
        "variables": workflow.get_available_variables(),
        "models": workflow.get_step_model_selections(),
        "output_models": workflow.get_output_model_selections(),
    }
    if isinstance(workflow, TossupWorkflow):
        state["buzzer"] = workflow.buzzer.model_dump(exclude_defaults=False)
    return state


def strict_model_validate(model_cls: type[BaseModel], data: dict):
    """Validate ``data`` against ``model_cls`` while rejecting unknown fields.

    A throwaway subclass named ``Strict<ModelName>`` is created on every call
    with ``extra="forbid"`` layered on top of the model's existing config.

    Args:
        model_cls: The pydantic model class to validate against.
        data: The raw data to validate.

    Returns:
        The result of ``model_validate`` on the strict subclass, i.e. an
        instance of that subclass (and therefore of ``model_cls``).
    """
    class_name = model_cls.__name__
    strict_class_name = f"Strict{class_name}"
    strict_class = type(
        strict_class_name,
        (model_cls,),
        {"model_config": {**getattr(model_cls, "model_config", {}), "extra": "forbid"}},
    )
    return strict_class.model_validate(data)


class BasePipelineValidator(ABC):
    """Abstract base class for pipeline validators."""

    @abstractmethod
    def __call__(self, workflow: Workflow):
        """
        Validate the workflow.

        Args:
            workflow: The workflow to validate.

        Raises:
            ValueError: If the workflow is invalid.
        """


class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    Every public method takes the serialized state dict, rebuilds a
    ``pipeline_state_cls`` instance from it, applies one change and returns
    the re-serialized state. Methods that also accept a boolean change flag
    return that flag toggled when the pipeline was modified.
    """

    # Overridden by subclasses to select the concrete state/workflow types.
    pipeline_state_cls = PipelineState
    workflow_cls = Workflow

    def __init__(self, validator: BasePipelineValidator | None = None):
        """Store the optional validator applied to workflows parsed from code."""
        self.validator = validator

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
        """Make a state from a state dictionary."""
        return self.pipeline_state_cls(**state_dict)

    def create_pipeline_state_dict(self, workflow: Workflow, ui_state: PipelineUIState) -> td.PipelineStateDict:
        """Create a serialized pipeline state from a workflow and its UI state."""
        return self.pipeline_state_cls(workflow=workflow, ui_state=ui_state).model_dump()

    def add_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name: str = ""
    ) -> tuple[td.PipelineStateDict, bool]:
        """Create a new LLM step and insert it into the pipeline.

        Args:
            state_dict: The serialized pipeline state.
            pipeline_change: The current change flag.
            position: The position passed to ``insert_step``.
            name: The step name; defaults to ``"Step <n_steps + 1>"`` when empty.

        Returns:
            The updated serialized state and the toggled change flag.
        """
        state = self.make_pipeline_state(state_dict)
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state.model_dump(), not pipeline_change

    def remove_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Remove the step at ``position`` from the pipeline.

        Returns:
            The updated serialized state and the toggled change flag.

        Raises:
            ValueError: If ``position`` is outside ``[0, n_steps)``.
        """
        state = self.make_pipeline_state(state_dict)
        if not 0 <= position < state.n_steps:
            raise ValueError(f"Invalid step position: {position}")
        state = state.remove_step(position)
        return state.model_dump(), not pipeline_change

    def _move_step(
        self, state_dict: td.PipelineStateDict, position: int, direction: Literal["up", "down"]
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step and report whether the step order actually changed."""
        state = self.make_pipeline_state(state_dict)
        old_order = list(state.ui_state.step_ids)
        # The return value is ignored, so move_item is assumed to reorder the
        # list in place; the before/after comparison detects a no-op move.
        utils.move_item(state.ui_state.step_ids, position, direction)
        return state.model_dump(), old_order != list(state.ui_state.step_ids)

    def move_up(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step up in the pipeline.

        Returns:
            The updated serialized state and the change flag, toggled only if
            the step order changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "up")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def move_down(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step down in the pipeline.

        Returns:
            The updated serialized state and the change flag, toggled only if
            the step order changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "down")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def update_model_step_state(
        self, state_dict: td.PipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.PipelineStateDict:
        """Update a particular model step in the pipeline."""
        state = self.make_pipeline_state(state_dict)
        state = state.update_step(model_step, ui_state)
        return state.model_dump()

    def update_output_variables(
        self, state_dict: td.PipelineStateDict, target: str, produced_variable: str
    ) -> td.PipelineStateDict:
        """Update the output variables for a step.

        ``produced_variable`` equal to ``UNSELECTED_VAR_NAME`` is forwarded as
        ``None`` (presumably the placeholder for "nothing selected").
        """
        selected = None if produced_variable == UNSELECTED_VAR_NAME else produced_variable
        state = self.make_pipeline_state(state_dict)
        state = state.update_output_variable(target, selected)
        return state.model_dump()

    def update_model_step_ui(
        self, state_dict: td.PipelineStateDict, step_ui: ModelStepUIState, step_id: str
    ) -> td.PipelineStateDict:
        """Replace the UI state of step ``step_id`` with a copy of ``step_ui``."""
        state = self.make_pipeline_state(state_dict)
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state.model_dump()

    def get_all_variables(self, state_dict: td.PipelineStateDict, model_step_id: str | None = None) -> list[str]:
        """Get all variables from all steps."""
        # NOTE(review): this returns the pipeline state object, not a list of
        # variable names, and ``model_step_id`` is unused. Behavior is left
        # untouched because the intended result cannot be confirmed from this
        # file; verify against callers before relying on the annotation.
        return self.make_pipeline_state(state_dict)

    def parse_yaml_workflow(self, yaml_str: str, strict: bool = True) -> Workflow:
        """Parse a YAML workflow.

        Args:
            yaml_str: The YAML document describing the workflow.
            strict: When true, unknown fields are rejected.

        Returns:
            The validated ``workflow_cls`` instance.

        Raises:
            ValidationError: If the parsed data does not match the schema. The
                error is rebuilt so its title no longer carries the "Strict"
                prefix of the temporary strict subclass.
        """
        workflow = yaml.safe_load(yaml_str)
        try:
            if strict:
                return strict_model_validate(self.workflow_cls, workflow)
            return self.workflow_cls.model_validate(workflow)
        except ValidationError as e:
            new_exception = ValidationError.from_exception_data(
                e.title.removeprefix("Strict"), e.errors(), input_type="json"
            )
            raise new_exception from e

    def create_pipeline_error_response(self, e: Exception) -> str:
        """Format error messages for pipeline parsing errors with consistent styling."""
        error_template = "\n{error_type}:\n{error_message}\n{help_text}\n"
        if isinstance(e, yaml.YAMLError):
            error_type = "Invalid YAML Error"
            help_text = "Refer to the YAML schema for correct formatting."
        # Checked before ValueError: pydantic's ValidationError is itself a
        # ValueError subclass and would otherwise be reported as the latter.
        elif isinstance(e, ValidationError):
            error_type = "Pipeline Parsing Error"
            help_text = "Refer to the documentation for the correct pipeline schema."
        elif isinstance(e, ValueError):
            error_type = "Pipeline Validation Error"
            help_text = "Refer to the documentation for the correct pipeline schema."
        else:
            error_type = "Unexpected Error"
            help_text = "Please report this issue to us at GitHub Issues."
        return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)

    def get_formatted_config(
        self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
    ) -> tuple[str, dict]:
        """Get the full pipeline configuration.

        Args:
            state_dict: The serialized pipeline state.
            format: ``"yaml"`` for a YAML dump; any other value yields JSON.

        Returns:
            On success, the configuration text and a ``gr.update`` that hides
            the error display. On any exception, ``gr.skip()`` and a
            ``gr.update`` that shows the formatted error message.
        """
        try:
            state = self.make_pipeline_state(state_dict)
            config = state.workflow.model_dump(exclude_defaults=True)
            if isinstance(state.workflow, TossupWorkflow):
                # The buzzer is always written out in full, defaults included.
                buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
                config["buzzer"] = buzzer_config
            if format == "yaml":
                config_str = yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
            else:
                config_str = json.dumps(config, indent=4, sort_keys=False)
            return config_str, gr.update(visible=False)
        except Exception as e:
            error_message = self.create_pipeline_error_response(e)
            return gr.skip(), gr.update(value=error_message, visible=True)

    def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool, dict]:
        """Update a workflow from a YAML string.

        The YAML is parsed strictly, passed to the validator when one is
        configured, and turned into a fresh pipeline state.

        Returns:
            On success, the serialized state, the toggled change flag and a
            ``gr.update`` that hides the error display. On any exception,
            two ``gr.skip()`` values and a ``gr.update`` that shows the
            formatted error message.
        """
        try:
            workflow = self.parse_yaml_workflow(yaml_str, strict=True)
            if self.validator is not None:
                self.validator(workflow)
            state = self.pipeline_state_cls.from_workflow(workflow)
            return state.model_dump(), not change_state, gr.update(visible=False)
        except Exception as e:
            error_message = self.create_pipeline_error_response(e)
            return gr.skip(), gr.skip(), gr.update(value=error_message, visible=True)


class TossupPipelineStateManager(PipelineStateManager):
    """Manages a tossup pipeline state.

    Most overrides only narrow the type annotations and delegate to the base
    class; ``update_buzzer`` is the tossup-specific addition.
    """

    pipeline_state_cls = TossupPipelineState
    workflow_cls = TossupWorkflow

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
        """Make a tossup pipeline state from a state dictionary."""
        return super().make_pipeline_state(state_dict)

    def create_pipeline_state_dict(
        self, workflow: TossupWorkflow, ui_state: PipelineUIState
    ) -> td.TossupPipelineStateDict:
        """Create a serialized tossup pipeline state from a workflow and its UI state."""
        return super().create_pipeline_state_dict(workflow, ui_state)

    def update_workflow_from_code(
        self, yaml_str: str, change_state: bool
    ) -> tuple[td.TossupPipelineStateDict, bool, dict]:
        """Update a tossup workflow from a YAML string."""
        return super().update_workflow_from_code(yaml_str, change_state)

    def update_model_step_state(
        self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.TossupPipelineStateDict:
        """Update a particular model step in the tossup pipeline."""
        return super().update_model_step_state(state_dict, model_step, ui_state)

    def update_output_variables(
        self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
    ) -> td.TossupPipelineStateDict:
        """Update the output variables for a step of the tossup pipeline."""
        return super().update_output_variables(state_dict, target, produced_variable)

    def update_buzzer(
        self,
        state_dict: td.TossupPipelineStateDict,
        confidence_threshold: float,
        method: str,
        tokens_prob: float | None,
    ) -> td.TossupPipelineStateDict:
        """Update the buzzer.

        Args:
            state_dict: The serialized tossup pipeline state.
            confidence_threshold: The buzzer's confidence threshold.
            method: The buzzer method.
            tokens_prob: The probability threshold; a missing, zero or
                negative value is treated as "no threshold".

        Returns:
            The serialized state with the buzzer replaced.
        """
        state = self.make_pipeline_state(state_dict)
        prob_threshold = float(tokens_prob) if tokens_prob and tokens_prob > 0 else None
        # With the OR method an absent probability threshold becomes 0.0
        # rather than None.
        if method == BuzzerMethod.OR and prob_threshold is None:
            prob_threshold = 0.0
        state.workflow.buzzer = Buzzer(
            method=method, confidence_threshold=confidence_threshold, prob_threshold=prob_threshold
        )
        return state.model_dump()