Maharshi Gor
Enhances model selection and logging in pipeline components; adds logprobs support and improves UI feedback for disabled sliders.
4f5d1cb
import json | |
from typing import Any, Literal | |
import gradio as gr | |
import yaml | |
from loguru import logger | |
from pydantic import BaseModel, Field | |
from components import utils | |
from workflows.factory import create_new_llm_step | |
from workflows.structs import ModelStep, TossupWorkflow, Workflow | |
def make_step_id(step_number: int):
    """Convert a zero-based step number into its letter-based step id.

    Numbers 0-25 map to "A"-"Z". From 26 onward the id has two letters:
    26 -> "AA", 27 -> "AB", ..., 51 -> "AZ", 52 -> "BA", and so on.
    """
    base = ord("A")
    if step_number >= 26:
        # The quotient selects the leading letter (shifted by one so that the
        # first two-letter block starts at "A"); the remainder selects the
        # trailing letter.
        block, offset = divmod(step_number, 26)
        return chr(base + block - 1) + chr(base + offset)
    return chr(base + step_number)
def make_step_number(step_id: str):
    """Convert a letter-based step id back into its zero-based step number.

    This is the inverse of ``make_step_id``: "A"-"Z" map to 0-25, and
    two-letter ids continue from there ("AA" -> 26, "AB" -> 27, "BA" -> 52).

    Args:
        step_id: A one- or two-letter uppercase step id.

    Returns:
        The zero-based step number encoded by ``step_id``.
    """
    base = ord("A")
    if len(step_id) == 1:
        return ord(step_id) - base
    # Two-letter ids start right after the 26 single-letter ids, so the
    # leading letter counts whole blocks of 26 starting from one ("A" -> 1),
    # and the trailing letter is the plain offset inside that block.
    leading = ord(step_id[0]) - base + 1
    trailing = ord(step_id[1]) - base
    return leading * 26 + trailing
class ModelStepUIState(BaseModel):
    """UI-only state attached to a single model step component."""

    # Defaults: the step starts expanded, with the model tab selected.
    expanded: bool = True
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state in which ``key`` is set to ``value``."""
        return self.model_copy(update={key: value})
class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    ``step_ids`` is the ordered list of step ids (a step's position is its
    index in this list) and ``steps`` maps each step id to the UI state of
    that step.
    """

    step_ids: list[str] = Field(default_factory=list)
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Fill ``steps`` with default UI states when only ``step_ids`` was given."""
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in ``step_ids``, or ``None`` if absent."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    # NOTE(review): PipelineState exposes ``n_steps`` as an attribute-style
    # accessor; confirm whether callers expect the same here. Left as a plain
    # method because no call site is visible in this file.
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline UI state from a workflow.

        Every step of ``workflow`` gets a default ``ModelStepUIState``, and
        the step order follows the iteration order of ``workflow.steps``.
        """
        # Declared as a classmethod (the signature takes ``cls``) so that it
        # can be called on the class without an existing instance.
        return cls(
            step_ids=list(workflow.steps.keys()),
            steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
        )
class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Bundles the ``workflow`` (steps and output mapping) with the ``ui_state``
    that tracks step ordering and per-step UI state.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Insert ``step`` at ``position`` and return ``self``.

        Args:
            position: Index in the step order at which to insert, or ``-1``
                to append at the end.
            step: The step to add; its ``id`` must not already be present.

        Raises:
            ValueError: If a step with the same id already exists, or if
                ``position`` is neither ``-1`` nor within ``0..n_steps``.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position
        if position != -1 and not 0 <= position <= self.n_steps:
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        self.workflow.steps[step.id] = step

        # Rebind ui_state to a fresh copy before editing it.
        # NOTE(review): presumably done so that observers comparing object
        # identity notice the change; confirm against the UI layer.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps[step.id] = ModelStepUIState()
        if position == -1:
            self.ui_state.step_ids.append(step.id)
        else:
            self.ui_state.step_ids.insert(position, step.id)
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step at ``position`` in the step order and return ``self``.

        Output mappings that pointed at variables which are no longer
        available are reset to ``None``.
        """
        step_id = self.ui_state.step_ids.pop(position)
        self.workflow.steps.pop(step_id)
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps.pop(step_id)
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Reset every workflow output whose variable is no longer available."""
        available_variables = set(self.available_variables)
        for output_field in self.workflow.outputs:
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    # ``available_variables`` and ``n_steps`` are read as plain attributes
    # throughout this file (e.g. ``set(self.available_variables)`` and
    # ``position > self.n_steps``), so they must be properties.
    @property
    def available_variables(self) -> list[str]:
        """Variables currently available, as reported by the workflow."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step.

        Returns "A" for an empty workflow; otherwise the id that follows the
        highest-numbered existing step id.
        """
        if not self.workflow.steps:
            return "A"
        last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
        return make_step_id(last_step_number + 1)
class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    Most methods take the current ``PipelineState`` (or ``PipelineUIState``)
    as an argument and return the updated objects rather than keeping state
    on the manager itself.
    """

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Get the full pipeline configuration as a YAML or JSON string.

        The workflow is dumped without default values. For a
        ``TossupWorkflow`` the ``buzzer`` section is dumped in full,
        including defaults.
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if isinstance(state.workflow, TossupWorkflow):
            buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
            config["buzzer"] = buzzer_config
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self):
        """Return a ``gr.State`` holding the number of steps."""
        # NOTE(review): ``self.steps`` is not set anywhere in this class as
        # shown; this raises AttributeError unless it is assigned elsewhere.
        return gr.State(len(self.steps))

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new step, insert it, and return the updated state.

        Args:
            state: Pipeline state to extend.
            position: Index at which to insert the step; ``-1`` appends.
            name: Step name; defaults to ``"Step <n_steps + 1>"`` when empty.

        Returns:
            A tuple ``(state, ui_state, available_variables)``.
        """
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at ``position`` from the pipeline.

        Returns:
            A tuple ``(state, ui_state, available_variables)``.

        Raises:
            ValueError: If ``position`` is outside ``0..n_steps - 1``.
        """
        if not 0 <= position < state.n_steps:
            raise ValueError(f"Invalid step position: {position}")
        state = state.remove_step(position)
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move a step up in the pipeline and return a copy of the UI state."""
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move a step down in the pipeline and return a copy of the UI state."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Store an updated step and its UI state in the pipeline.

        Returns:
            A tuple ``(state, ui_state, available_variables, model_selections)``
            where ``model_selections`` maps each step id to the value of that
            step's ``get_full_model_name()``.
        """
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        state.ui_state = state.ui_state.model_copy()
        # The step's produced variables may have changed, so drop output
        # mappings that no longer point at an available variable.
        state.update_output_variables_mapping()
        model_selections = {step_id: step.get_full_model_name() for step_id, step in state.workflow.steps.items()}
        return state, state.ui_state, state.available_variables, model_selections

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Update the output variables for a step.

        Maps the workflow output ``target`` to ``produced_variable``. The
        placeholder value "Choose variable..." is stored as ``None``.
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Store a copy of ``step_ui`` as the UI state of step ``step_id``.

        Returns:
            A tuple ``(state, ui_state)``.
        """
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Get the available variables, optionally excluding one step's own.

        When ``model_step_id`` is given, variables whose name starts with
        ``"<model_step_id>."`` are left out of the result.
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self):
        """Get the full pipeline configuration."""
        # NOTE(review): ``self.workflow`` is not set anywhere in this class
        # as shown; this raises AttributeError unless it is assigned elsewhere.
        return self.workflow