Maharshi Gor
Updated workflow APIs, code clean up and minor functions for hf pipeline support
f064c62
# %% | |
import json | |
from abc import ABC, abstractmethod | |
from typing import Literal | |
import gradio as gr | |
import yaml | |
from loguru import logger | |
from pydantic import BaseModel, ValidationError | |
from app_configs import UNSELECTED_VAR_NAME | |
from components import typed_dicts as td | |
from components import utils | |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState | |
from envs import DOCS_REPO_URL, DOCS_URL | |
from shared.workflows.factory import create_new_llm_step | |
from shared.workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow | |
def get_output_panel_state(workflow: Workflow) -> dict:
    """Collect the data the output panel needs from ``workflow``.

    Returns a dict with the workflow's available variables (``"variables"``), per-step model
    selections (``"models"``) and output model selections (``"output_models"``). For a
    ``TossupWorkflow`` it also contains ``"buzzer"``: the full buzzer dump, defaults included.
    """
    panel_state = dict(
        variables=workflow.get_available_variables(),
        models=workflow.get_step_model_selections(),
        output_models=workflow.get_output_model_selections(),
    )
    if not isinstance(workflow, TossupWorkflow):
        return panel_state
    panel_state["buzzer"] = workflow.buzzer.model_dump(exclude_defaults=False)
    return panel_state
# Generated "Strict<Name>" subclasses, keyed by the model class they were derived from.
_STRICT_MODEL_CLASSES: dict[type[BaseModel], type[BaseModel]] = {}


def strict_model_validate(model_cls: type[BaseModel], data: dict):
    """Validate ``data`` against ``model_cls`` while rejecting unknown keys.

    Validation goes through a subclass named ``Strict<ModelName>`` whose ``model_config`` is
    the parent's config with ``extra="forbid"``. The subclass is created once per
    ``model_cls`` and then reused. Only the subclass's own config is changed; nested model
    classes are not modified. Because validation happens on the subclass, errors raised here
    carry the ``Strict…`` class name.

    Args:
        model_cls: The pydantic model class to validate against.
        data: The raw data to validate.

    Returns:
        An instance of the strict subclass, which is also an instance of ``model_cls``.

    Raises:
        pydantic.ValidationError: If ``data`` is invalid or contains keys not defined on
            the model.
    """
    strict_cls = _STRICT_MODEL_CLASSES.get(model_cls)
    if strict_cls is None:
        # Creating a pydantic model class (including schema generation) is expensive, so
        # build the strict variant once per model class instead of on every call.
        strict_cls = type(
            f"Strict{model_cls.__name__}",
            (model_cls,),
            {"model_config": {**getattr(model_cls, "model_config", {}), "extra": "forbid"}},
        )
        _STRICT_MODEL_CLASSES[model_cls] = strict_cls
    return strict_cls.model_validate(data)
class BasePipelineValidator(ABC):
    """Base class for callables that vet a workflow before it is accepted.

    Subclasses override ``__call__``. This default implementation accepts every workflow.
    """

    def __call__(self, workflow: Workflow):
        """Check ``workflow`` and raise if it is not acceptable.

        Args:
            workflow: The workflow to check.

        Raises:
            ValueError: Expected from subclasses when the workflow is invalid. The base
                implementation never raises and returns ``None``.
        """
        return None
class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    Each method rebuilds a state object from a serialized state dict via
    ``pipeline_state_cls`` and applies one change. It then returns the re-serialized dict
    (``model_dump()``). Methods that accept a ``pipeline_change`` flag also return it,
    toggled when the step structure changed. Presumably this lets UI listeners bound to
    that flag re-render (verify against the callers).
    """

    # Concrete state / workflow classes; subclasses (e.g. TossupPipelineStateManager) override them.
    pipeline_state_cls = PipelineState
    workflow_cls = Workflow

    def __init__(self, validator: BasePipelineValidator | None = None):
        """Create a manager, optionally with an extra workflow validator.

        Args:
            validator: Optional callable that ``update_workflow_from_code`` runs on every
                workflow parsed from code. It rejects a workflow by raising.
        """
        self.validator = validator

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
        """Make a state from a state dictionary."""
        return self.pipeline_state_cls(**state_dict)

    def create_pipeline_state_dict(self, workflow: Workflow, ui_state: PipelineUIState) -> td.PipelineStateDict:
        """Create a serialized pipeline state from a workflow and its UI state."""
        return self.pipeline_state_cls(workflow=workflow, ui_state=ui_state).model_dump()

    def add_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
    ) -> tuple[td.PipelineStateDict, bool]:
        """Insert a new LLM step into the pipeline.

        Args:
            state_dict: Serialized pipeline state.
            pipeline_change: Change flag; it is returned toggled.
            position: Insertion position, forwarded to ``insert_step`` (default ``-1``).
            name: Step name; defaults to ``"Step <n_steps + 1>"``.

        Returns:
            The updated state dict and the toggled ``pipeline_change`` flag.
        """
        state = self.make_pipeline_state(state_dict)
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state.model_dump(), not pipeline_change

    def remove_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Remove the step at ``position`` from the pipeline.

        Returns:
            The updated state dict and the toggled ``pipeline_change`` flag.

        Raises:
            ValueError: If ``position`` is outside ``[0, n_steps)``.
        """
        state = self.make_pipeline_state(state_dict)
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state.model_dump(), not pipeline_change

    def _move_step(
        self, state_dict: td.PipelineStateDict, position: int, direction: Literal["up", "down"]
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move the step at ``position`` one slot in ``direction`` within the UI step order.

        Returns:
            The updated state dict and whether the step order actually changed. Moves at
            the ends of the list are presumably no-ops inside ``utils.move_item``.
        """
        state = self.make_pipeline_state(state_dict)
        old_order = list(state.ui_state.step_ids)
        # utils.move_item reorders the list in place.
        utils.move_item(state.ui_state.step_ids, position, direction)
        return state.model_dump(), old_order != list(state.ui_state.step_ids)

    def move_up(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step up in the pipeline.

        Returns:
            The updated state dict and ``pipeline_change``, toggled only if the order changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "up")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def move_down(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step down in the pipeline.

        Returns:
            The updated state dict and ``pipeline_change``, toggled only if the order changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "down")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def update_model_step_state(
        self, state_dict: td.PipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.PipelineStateDict:
        """Update a particular model step (and its UI state) in the pipeline."""
        state = self.make_pipeline_state(state_dict)
        state = state.update_step(model_step, ui_state)
        return state.model_dump()

    def update_output_variables(
        self, state_dict: td.PipelineStateDict, target: str, produced_variable: str
    ) -> td.PipelineStateDict:
        """Update the output variables for a step.

        Args:
            state_dict: Serialized pipeline state.
            target: Output identifier, forwarded to ``update_output_variable``.
            produced_variable: Variable to assign. ``UNSELECTED_VAR_NAME`` is translated to
                ``None`` (no selection).
        """
        if produced_variable == UNSELECTED_VAR_NAME:
            produced_variable = None
        state = self.make_pipeline_state(state_dict)
        state = state.update_output_variable(target, produced_variable)
        return state.model_dump()

    def update_model_step_ui(
        self, state_dict: td.PipelineStateDict, step_ui: ModelStepUIState, step_id: str
    ) -> td.PipelineStateDict:
        """Replace the UI state of step ``step_id`` with a copy of ``step_ui``."""
        state = self.make_pipeline_state(state_dict)
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state.model_dump()

    def get_all_variables(self, state_dict: td.PipelineStateDict, model_step_id: str | None = None) -> list[str]:
        """Get all variables from all steps."""
        # NOTE(review): despite the docstring and the ``list[str]`` annotation, this returns the
        # rebuilt pipeline state object, and ``model_step_id`` is ignored. Behavior is kept
        # as-is for existing callers; confirm the intent (perhaps
        # ``state.workflow.get_available_variables()``).
        return self.make_pipeline_state(state_dict)

    def parse_yaml_workflow(self, yaml_str: str, strict: bool = True) -> Workflow:
        """Parse a YAML workflow.

        Args:
            yaml_str: YAML text describing the workflow.
            strict: If True, validate via ``strict_model_validate`` so unknown keys are rejected.

        Returns:
            The validated workflow instance.

        Raises:
            yaml.YAMLError: If ``yaml_str`` is not valid YAML.
            pydantic.ValidationError: If the data does not match ``workflow_cls``.
        """
        workflow = yaml.safe_load(yaml_str)
        try:
            if strict:
                return strict_model_validate(self.workflow_cls, workflow)
            return self.workflow_cls.model_validate(workflow)
        except ValidationError as e:
            # Strict validation runs on a generated "Strict<Name>" subclass; re-raise the errors
            # under the original model name so users don't see the internal class name.
            new_exception = ValidationError.from_exception_data(
                e.title.removeprefix("Strict"), e.errors(), input_type="json"
            )
            raise new_exception from e

    def create_pipeline_error_response(self, e: Exception) -> str:
        """Format error messages for pipeline parsing errors with consistent styling.

        Args:
            e: The exception raised while parsing, validating or serializing a pipeline.

        Returns:
            An HTML snippet with a heading chosen by exception type, the HTML-escaped error
            message and a help link.
        """
        import html  # only needed on this error path

        error_template = """
<div class="md" style='color: #FF0000; background-color: #FFEEEE; padding: 10px; border-radius: 5px; border-left: 4px solid #FF0000;'>
    <strong style='color: #FF0000;'>{error_type}:</strong> <br>
    <div class="code-wrap">
        <pre><code>{error_message}</code></pre>
    </div>
    {help_text}
</div>
"""
        if isinstance(e, yaml.YAMLError):
            error_type = "Invalid YAML Error"
            help_text = "Refer to the <a href='https://spacelift.io/blog/yaml#basic-yaml-syntax' target='_blank'>YAML schema</a> for correct formatting."
        # ValidationError must be checked before ValueError: pydantic's ValidationError subclasses ValueError.
        elif isinstance(e, ValidationError):
            error_type = "Pipeline Parsing Error"
            help_text = f"Refer to the <a href='{DOCS_URL}/pipeline-schema.md' target='_blank'>documentation</a> for the correct pipeline schema."
        elif isinstance(e, ValueError):
            error_type = "Pipeline Validation Error"
            help_text = f"Refer to the <a href='{DOCS_URL}/pipeline-schema.md' target='_blank'>documentation</a> for the correct pipeline schema."
        else:
            error_type = "Unexpected Error"
            help_text = (
                f"Please report this issue to us at <a href='{DOCS_REPO_URL}/issues' target='_blank'>GitHub Issues</a>."
            )
        # Error messages echo user-supplied YAML/input; escape them so they display literally
        # and cannot inject markup into the page.
        return error_template.format(
            error_type=error_type, error_message=html.escape(str(e)), help_text=help_text
        )

    def get_formatted_config(
        self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
    ) -> tuple[str, dict]:
        """Get the full pipeline configuration as YAML or JSON text.

        Returns:
            On success, ``(config_str, gr.update(visible=False))``. On failure,
            ``(gr.skip(), gr.update(...))``, where the update shows the formatted error HTML.
        """
        try:
            state = self.make_pipeline_state(state_dict)
            config = state.workflow.model_dump(exclude_defaults=True)
            if isinstance(state.workflow, TossupWorkflow):
                # Always export every buzzer field, including defaults.
                buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
                config["buzzer"] = buzzer_config
            if format == "yaml":
                config_str = yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
            else:
                config_str = json.dumps(config, indent=4, sort_keys=False)
            return config_str, gr.update(visible=False)
        except Exception as e:
            error_message = self.create_pipeline_error_response(e)
            return gr.skip(), gr.update(value=error_message, visible=True)

    def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool, dict]:
        """Update a workflow from a YAML string.

        Returns:
            On success, the new state dict, the toggled ``change_state`` flag and an update
            hiding the error panel. On any failure, two ``gr.skip()`` values and an update
            showing the formatted error HTML.
        """
        try:
            workflow = self.parse_yaml_workflow(yaml_str, strict=True)
            if self.validator:
                self.validator(workflow)
            state = self.pipeline_state_cls.from_workflow(workflow)
            return state.model_dump(), not change_state, gr.update(visible=False)
        except Exception as e:
            # UI boundary: report every failure to the user instead of raising.
            error_message = self.create_pipeline_error_response(e)
            return gr.skip(), gr.skip(), gr.update(value=error_message, visible=True)
class TossupPipelineStateManager(PipelineStateManager):
    """Pipeline state manager specialised for tossup workflows.

    All logic comes from ``PipelineStateManager``, configured with the tossup state and
    workflow classes. The overrides below only narrow type hints; ``update_buzzer`` adds
    buzzer editing.
    """

    pipeline_state_cls = TossupPipelineState
    workflow_cls = TossupWorkflow

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
        """Rebuild a ``TossupPipelineState`` from its serialized dict."""
        return super().make_pipeline_state(state_dict)

    def create_pipeline_state_dict(
        self, workflow: TossupWorkflow, ui_state: PipelineUIState
    ) -> td.TossupPipelineStateDict:
        """Serialize a tossup workflow together with its UI state."""
        return super().create_pipeline_state_dict(workflow, ui_state)

    def update_workflow_from_code(
        self, yaml_str: str, change_state: bool
    ) -> tuple[td.TossupPipelineStateDict, bool, dict]:
        """Replace the tossup workflow with one parsed from ``yaml_str``."""
        return super().update_workflow_from_code(yaml_str, change_state)

    def update_model_step_state(
        self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.TossupPipelineStateDict:
        """Update one model step (and its UI state) in the tossup pipeline."""
        return super().update_model_step_state(state_dict, model_step, ui_state)

    def update_output_variables(
        self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
    ) -> td.TossupPipelineStateDict:
        """Update an output variable mapping of the tossup pipeline."""
        return super().update_output_variables(state_dict, target, produced_variable)

    def update_buzzer(
        self,
        state_dict: td.TossupPipelineStateDict,
        confidence_threshold: float,
        method: str,
        tokens_prob: float | None,
    ) -> td.TossupPipelineStateDict:
        """Replace the workflow's buzzer with one built from the given settings.

        A missing, zero or negative ``tokens_prob`` means no probability threshold
        (``None``). For ``BuzzerMethod.OR`` that case is stored as ``0.0`` instead.
        """
        state = self.make_pipeline_state(state_dict)

        threshold = None
        if tokens_prob and tokens_prob > 0:
            threshold = float(tokens_prob)
        if method == BuzzerMethod.OR and threshold is None:
            threshold = 0.0

        new_buzzer = Buzzer(
            method=method,
            confidence_threshold=confidence_threshold,
            prob_threshold=threshold,
        )
        state.workflow.buzzer = new_buzzer
        return state.model_dump()