Maharshi Gor commited on
Commit
849566b
·
1 Parent(s): c1ae336

Add user input validation to pipeline interfaces error display on pipeline change.

Browse files
src/components/model_pipeline/model_pipeline.py CHANGED
@@ -10,13 +10,14 @@ from components.model_pipeline.state_manager import (
10
  PipelineState,
11
  PipelineStateManager,
12
  PipelineUIState,
 
13
  TossupPipelineState,
14
  TossupPipelineStateManager,
15
  )
16
  from components.model_step.model_step import ModelStepComponent
17
  from components.utils import make_state
18
  from workflows.structs import ModelStep, TossupWorkflow, Workflow
19
- from workflows.validators import WorkflowValidator
20
 
21
  from .state_manager import get_output_panel_state
22
 
@@ -33,6 +34,7 @@ class PipelineInterface:
33
  ui_state: PipelineUIState | None = None,
34
  model_options: list[str] = None,
35
  config: dict = {},
 
36
  ):
37
  self.app = app
38
  self.model_options = model_options
@@ -50,10 +52,10 @@ class PipelineInterface:
50
 
51
  if isinstance(workflow, TossupWorkflow):
52
  pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
53
- self.sm = TossupPipelineStateManager()
54
  else:
55
  pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
56
- self.sm = PipelineStateManager()
57
  self.pipeline_state = make_state(pipeline_state.model_dump())
58
 
59
  def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
@@ -169,7 +171,11 @@ class PipelineInterface:
169
  """Validate the workflow."""
170
  try:
171
  state = self.sm.make_pipeline_state(state_dict)
172
- WorkflowValidator().validate(state.workflow)
 
 
 
 
173
  except ValueError as e:
174
  logger.exception(e)
175
  state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
@@ -244,6 +250,7 @@ class PipelineInterface:
244
  def _render_pipeline_preview(self):
245
  export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
246
  # components.append(export_btn)
 
247
 
248
  # Add a code box to display the workflow JSON
249
  # with gr.Column(elem_classes="workflow-json-container"):
@@ -262,7 +269,7 @@ class PipelineInterface:
262
  self.config_output.blur(
263
  fn=self.sm.update_workflow_from_code,
264
  inputs=[self.config_output, self.pipeline_change],
265
- outputs=[self.pipeline_state, self.pipeline_change],
266
  )
267
 
268
  # Connect the export button to show the workflow JSON
@@ -326,6 +333,6 @@ class PipelineInterface:
326
  ).success(
327
  fn=self.sm.get_formatted_config,
328
  inputs=[self.pipeline_state, gr.State("yaml")],
329
- outputs=[self.config_output],
330
  js=js,
331
  )
 
10
  PipelineState,
11
  PipelineStateManager,
12
  PipelineUIState,
13
+ PipelineValidator,
14
  TossupPipelineState,
15
  TossupPipelineStateManager,
16
  )
17
  from components.model_step.model_step import ModelStepComponent
18
  from components.utils import make_state
19
  from workflows.structs import ModelStep, TossupWorkflow, Workflow
20
+ from workflows.validators import WorkflowValidationError, WorkflowValidator
21
 
22
  from .state_manager import get_output_panel_state
23
 
 
34
  ui_state: PipelineUIState | None = None,
35
  model_options: list[str] = None,
36
  config: dict = {},
37
+ validator: PipelineValidator | None = None,
38
  ):
39
  self.app = app
40
  self.model_options = model_options
 
52
 
53
  if isinstance(workflow, TossupWorkflow):
54
  pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
55
+ self.sm = TossupPipelineStateManager(validator)
56
  else:
57
  pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
58
+ self.sm = PipelineStateManager(validator)
59
  self.pipeline_state = make_state(pipeline_state.model_dump())
60
 
61
  def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
 
171
  """Validate the workflow."""
172
  try:
173
  state = self.sm.make_pipeline_state(state_dict)
174
+ validator = WorkflowValidator(
175
+ max_temperature=self.config.get("max_temperature", 10),
176
+ )
177
+ if not validator.validate(state.workflow):
178
+ raise WorkflowValidationError(validator.errors)
179
  except ValueError as e:
180
  logger.exception(e)
181
  state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
 
250
  def _render_pipeline_preview(self):
251
  export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
252
  # components.append(export_btn)
253
+ self.error_display = gr.HTML(label="Error", elem_id="pipeline-preview-error-display", visible=False)
254
 
255
  # Add a code box to display the workflow JSON
256
  # with gr.Column(elem_classes="workflow-json-container"):
 
269
  self.config_output.blur(
270
  fn=self.sm.update_workflow_from_code,
271
  inputs=[self.config_output, self.pipeline_change],
272
+ outputs=[self.pipeline_state, self.pipeline_change, self.error_display],
273
  )
274
 
275
  # Connect the export button to show the workflow JSON
 
333
  ).success(
334
  fn=self.sm.get_formatted_config,
335
  inputs=[self.pipeline_state, gr.State("yaml")],
336
+ outputs=[self.config_output, self.error_display],
337
  js=js,
338
  )
src/components/model_pipeline/state_manager.py CHANGED
@@ -1,7 +1,12 @@
 
1
  import json
 
2
  from typing import Literal
3
 
 
4
  import yaml
 
 
5
 
6
  from app_configs import UNSELECTED_VAR_NAME
7
  from components import typed_dicts as td
@@ -22,24 +27,49 @@ def get_output_panel_state(workflow: Workflow) -> dict:
22
  return state
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class PipelineStateManager:
26
  """Manages a pipeline of multiple steps."""
27
 
 
 
 
 
 
 
28
  def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
29
  """Make a state from a state dictionary."""
30
- return PipelineState(**state_dict)
31
-
32
- def get_formatted_config(self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml") -> str:
33
- """Get the full pipeline configuration."""
34
- state = self.make_pipeline_state(state_dict)
35
- config = state.workflow.model_dump(exclude_defaults=True)
36
- if isinstance(state.workflow, TossupWorkflow):
37
- buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
38
- config["buzzer"] = buzzer_config
39
- if format == "yaml":
40
- return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
41
- else:
42
- return json.dumps(config, indent=4, sort_keys=False)
43
 
44
  def add_step(
45
  self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
@@ -102,7 +132,7 @@ class PipelineStateManager:
102
  produced_variable = None
103
  """Update the output variables for a step."""
104
  state = self.make_pipeline_state(state_dict)
105
- state.workflow.outputs[target] = produced_variable
106
  return state.model_dump()
107
 
108
  def update_model_step_ui(
@@ -117,53 +147,101 @@ class PipelineStateManager:
117
  """Get all variables from all steps."""
118
  return self.make_pipeline_state(state_dict)
119
 
120
- def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
121
  """Parse a YAML workflow."""
122
  workflow = yaml.safe_load(yaml_str)
123
- return Workflow(**workflow)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- def update_workflow_from_code(self, yaml_str: str) -> td.PipelineStateDict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  """Update a workflow from a YAML string."""
127
- workflow = self.parse_yaml_workflow(yaml_str)
128
- return PipelineState.from_workflow(workflow).model_dump()
 
 
 
 
 
 
129
 
130
 
131
  class TossupPipelineStateManager(PipelineStateManager):
132
  """Manages a tossup pipeline state."""
133
 
134
- def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
135
- """Make a state from a state dictionary."""
136
- return TossupPipelineState(**state_dict)
137
 
138
- def parse_yaml_workflow(self, yaml_str: str) -> TossupWorkflow:
139
- """Parse a YAML workflow."""
140
- workflow = yaml.safe_load(yaml_str)
141
- return TossupWorkflow(**workflow)
142
 
143
- def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool]:
144
- """Update a workflow from a YAML string."""
145
- workflow = self.parse_yaml_workflow(yaml_str)
146
- return TossupPipelineState.from_workflow(workflow).model_dump(), not change_state
147
 
148
  def update_model_step_state(
149
  self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
150
  ) -> td.TossupPipelineStateDict:
151
- """Update a particular model step in the pipeline."""
152
- state = self.make_pipeline_state(state_dict)
153
- state = state.update_step(model_step, ui_state)
154
- state.workflow = state.workflow.refresh_buzzer()
155
- return state.model_dump()
156
 
157
  def update_output_variables(
158
  self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
159
  ) -> td.TossupPipelineStateDict:
160
- if produced_variable == UNSELECTED_VAR_NAME:
161
- produced_variable = None
162
- """Update the output variables for a step."""
163
- state = self.make_pipeline_state(state_dict)
164
- state.workflow.outputs[target] = produced_variable
165
- state.workflow = state.workflow.refresh_buzzer()
166
- return state.model_dump()
167
 
168
  def update_buzzer(
169
  self,
 
1
+ # %%
2
  import json
3
+ from abc import ABC, abstractmethod
4
  from typing import Literal
5
 
6
+ import gradio as gr
7
  import yaml
8
+ from loguru import logger
9
+ from pydantic import BaseModel, ValidationError
10
 
11
  from app_configs import UNSELECTED_VAR_NAME
12
  from components import typed_dicts as td
 
27
  return state
28
 
29
 
30
+ def strict_model_validate(model_cls: type[BaseModel], data: dict):
31
+ # Dynamically create a subclass with extra='forbid'
32
+ class_name = model_cls.__name__
33
+ strict_class_name = f"Strict{class_name}"
34
+
35
+ strict_class = type(
36
+ strict_class_name,
37
+ (model_cls,),
38
+ {"model_config": {**getattr(model_cls, "model_config", {}), "extra": "forbid"}},
39
+ )
40
+
41
+ return strict_class.model_validate(data)
42
+
43
+
44
+ class PipelineValidator(ABC):
45
+ """Abstract base class for pipeline validators."""
46
+
47
+ @abstractmethod
48
+ def __call__(self, workflow: Workflow):
49
+ """
50
+ Validate the workflow.
51
+
52
+ Args:
53
+ workflow: The workflow to validate.
54
+
55
+ Raises:
56
+ ValueError: If the workflow is invalid.
57
+ """
58
+ pass
59
+
60
+
61
  class PipelineStateManager:
62
  """Manages a pipeline of multiple steps."""
63
 
64
+ pipeline_state_cls = PipelineState
65
+ workflow_cls = Workflow
66
+
67
+ def __init__(self, validator: PipelineValidator | None = None):
68
+ self.validator = validator
69
+
70
  def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
71
  """Make a state from a state dictionary."""
72
+ return self.pipeline_state_cls(**state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  def add_step(
75
  self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
 
132
  produced_variable = None
133
  """Update the output variables for a step."""
134
  state = self.make_pipeline_state(state_dict)
135
+ state = state.update_output_variable(target, produced_variable)
136
  return state.model_dump()
137
 
138
  def update_model_step_ui(
 
147
  """Get all variables from all steps."""
148
  return self.make_pipeline_state(state_dict)
149
 
150
+ def parse_yaml_workflow(self, yaml_str: str, strict: bool = True) -> Workflow:
151
  """Parse a YAML workflow."""
152
  workflow = yaml.safe_load(yaml_str)
153
+ try:
154
+ if strict:
155
+ return strict_model_validate(self.workflow_cls, workflow)
156
+ else:
157
+ return self.workflow_cls.model_validate(workflow)
158
+ except ValidationError as e:
159
+ new_exception = ValidationError.from_exception_data(
160
+ e.title.removeprefix("Strict"), e.errors(), input_type="json"
161
+ )
162
+ raise new_exception from e
163
+
164
+ def _handle_pipeline_parsing_error(self, e: Exception) -> str:
165
+ """Format error messages for pipeline parsing errors with consistent styling."""
166
+ error_template = """
167
+ <div class="md" style='color: #FF0000; background-color: #FFEEEE; padding: 10px; border-radius: 5px; border-left: 4px solid #FF0000;'>
168
+ <strong style='color: #FF0000;'>{error_type}:</strong> <br>
169
+ <div class="code-wrap">
170
+ <pre><code>{error_message}</code></pre>
171
+ </div>
172
+ {help_text}
173
+ </div>
174
+ """
175
+ logger.exception(e)
176
+ if isinstance(e, yaml.YAMLError):
177
+ error_type = "Invalid YAML Error"
178
+ help_text = "Refer to the <a href='https://spacelift.io/blog/yaml#basic-yaml-syntax' target='_blank'>YAML schema</a> for correct formatting."
179
+ elif isinstance(e, ValidationError):
180
+ error_type = "Pipeline Parsing Error"
181
+ help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
182
+ elif isinstance(e, ValueError):
183
+ error_type = "Pipeline Validation Error"
184
+ help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
185
+ else:
186
+ error_type = "Unexpected Error"
187
+ help_text = "Please report this issue to us at <a href='https://github.com/maharshi95/QANTA25/issues' target='_blank'>GitHub Issues</a>."
188
+
189
+ return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)
190
 
191
+ def get_formatted_config(
192
+ self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
193
+ ) -> tuple[str, dict]:
194
+ """Get the full pipeline configuration."""
195
+ try:
196
+ state = self.make_pipeline_state(state_dict)
197
+ config = state.workflow.model_dump(exclude_defaults=True)
198
+ if isinstance(state.workflow, TossupWorkflow):
199
+ buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
200
+ config["buzzer"] = buzzer_config
201
+ if format == "yaml":
202
+ config_str = yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
203
+ else:
204
+ config_str = json.dumps(config, indent=4, sort_keys=False)
205
+ return config_str, gr.update(visible=False)
206
+ except Exception as e:
207
+ error_message = self._handle_pipeline_parsing_error(e)
208
+ return gr.skip(), gr.update(value=error_message, visible=True)
209
+
210
+ def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool, dict]:
211
  """Update a workflow from a YAML string."""
212
+ try:
213
+ workflow = self.parse_yaml_workflow(yaml_str, strict=True)
214
+ self.validator and self.validator(workflow)
215
+ state = self.pipeline_state_cls.from_workflow(workflow)
216
+ return state.model_dump(), not change_state, gr.update(visible=False)
217
+ except Exception as e:
218
+ error_message = self._handle_pipeline_parsing_error(e)
219
+ return gr.skip(), gr.skip(), gr.update(value=error_message, visible=True)
220
 
221
 
222
  class TossupPipelineStateManager(PipelineStateManager):
223
  """Manages a tossup pipeline state."""
224
 
225
+ pipeline_state_cls = TossupPipelineState
226
+ workflow_cls = TossupWorkflow
 
227
 
228
+ def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
229
+ return super().make_pipeline_state(state_dict)
 
 
230
 
231
+ def update_workflow_from_code(
232
+ self, yaml_str: str, change_state: bool
233
+ ) -> tuple[td.TossupPipelineStateDict, bool, dict]:
234
+ return super().update_workflow_from_code(yaml_str, change_state)
235
 
236
  def update_model_step_state(
237
  self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
238
  ) -> td.TossupPipelineStateDict:
239
+ return super().update_model_step_state(state_dict, model_step, ui_state)
 
 
 
 
240
 
241
  def update_output_variables(
242
  self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
243
  ) -> td.TossupPipelineStateDict:
244
+ return super().update_output_variables(state_dict, target, produced_variable)
 
 
 
 
 
 
245
 
246
  def update_buzzer(
247
  self,
src/components/model_pipeline/tossup_pipeline.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import numpy as np
3
  from loguru import logger
4
 
5
  from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
@@ -9,7 +8,8 @@ from components.typed_dicts import TossupPipelineStateDict
9
  from display.formatting import tiny_styled_warning
10
  from workflows.structs import Buzzer, TossupWorkflow
11
 
12
- from .model_pipeline import PipelineInterface, PipelineState, PipelineUIState
 
13
 
14
 
15
  def toggleable_slider(
@@ -40,8 +40,9 @@ class TossupPipelineInterface(PipelineInterface):
40
  ui_state: PipelineUIState | None = None,
41
  model_options: list[str] = None,
42
  config: dict = {},
 
43
  ):
44
- super().__init__(app, workflow, ui_state, model_options, config)
45
 
46
  self.buzzer_state = gr.State(workflow.buzzer.model_dump())
47
 
 
1
  import gradio as gr
 
2
  from loguru import logger
3
 
4
  from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
 
8
  from display.formatting import tiny_styled_warning
9
  from workflows.structs import Buzzer, TossupWorkflow
10
 
11
+ from .model_pipeline import PipelineInterface
12
+ from .state_manager import PipelineUIState, PipelineValidator
13
 
14
 
15
  def toggleable_slider(
 
40
  ui_state: PipelineUIState | None = None,
41
  model_options: list[str] = None,
42
  config: dict = {},
43
+ validator: PipelineValidator | None = None,
44
  ):
45
+ super().__init__(app, workflow, ui_state, model_options, config, validator)
46
 
47
  self.buzzer_state = gr.State(workflow.buzzer.model_dump())
48
 
src/components/quizbowl/bonus.py CHANGED
@@ -19,6 +19,7 @@ from workflows.qb_agents import QuizBowlBonusAgent
19
  from . import populate, validation
20
  from .plotting import create_bonus_confidence_plot, create_bonus_html
21
  from .utils import evaluate_prediction
 
22
 
23
 
24
  def process_bonus_results(results: list[dict]) -> pd.DataFrame:
@@ -105,6 +106,7 @@ class BonusInterface:
105
  ui_state=pipeline_state.ui_state,
106
  model_options=list(self.model_options.keys()),
107
  config=self.defaults,
 
108
  )
109
 
110
  def _render_qb_interface(self):
 
19
  from . import populate, validation
20
  from .plotting import create_bonus_confidence_plot, create_bonus_html
21
  from .utils import evaluate_prediction
22
+ from .validation import UserInputWorkflowValidator
23
 
24
 
25
  def process_bonus_results(results: list[dict]) -> pd.DataFrame:
 
106
  ui_state=pipeline_state.ui_state,
107
  model_options=list(self.model_options.keys()),
108
  config=self.defaults,
109
+ validator=UserInputWorkflowValidator("bonus"),
110
  )
111
 
112
  def _render_qb_interface(self):
src/components/quizbowl/tossup.py CHANGED
@@ -25,6 +25,7 @@ from .plotting import (
25
  prepare_tossup_results_df,
26
  )
27
  from .utils import evaluate_prediction
 
28
 
29
  # TODO: Error handling on run tossup and evaluate tossup and show correct messages
30
  # TODO: ^^ Same for Bonus
@@ -135,7 +136,7 @@ class TossupInterface:
135
  self.output_state = gr.State(value={})
136
  self.render()
137
 
138
- # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------
139
 
140
  def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
141
  logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
@@ -165,6 +166,7 @@ class TossupInterface:
165
  ui_state=pipeline_state.ui_state,
166
  model_options=list(self.model_options.keys()),
167
  config=self.defaults,
 
168
  )
169
 
170
  def _render_qb_interface(self):
 
25
  prepare_tossup_results_df,
26
  )
27
  from .utils import evaluate_prediction
28
+ from .validation import UserInputWorkflowValidator
29
 
30
  # TODO: Error handling on run tossup and evaluate tossup and show correct messages
31
  # TODO: ^^ Same for Bonus
 
136
  self.output_state = gr.State(value={})
137
  self.render()
138
 
139
+ # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE ------------------------------------
140
 
141
  def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
142
  logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
 
166
  ui_state=pipeline_state.ui_state,
167
  model_options=list(self.model_options.keys()),
168
  config=self.defaults,
169
+ validator=UserInputWorkflowValidator("tossup"),
170
  )
171
 
172
  def _render_qb_interface(self):
src/components/quizbowl/validation.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from app_configs import CONFIGS
2
  from components.structs import PipelineState, TossupPipelineState
3
  from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
@@ -53,3 +55,31 @@ def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict):
53
  CONFIGS["bonus"]["required_output_vars"],
54
  )
55
  return pipeline_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+
3
  from app_configs import CONFIGS
4
  from components.structs import PipelineState, TossupPipelineState
5
  from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
 
55
  CONFIGS["bonus"]["required_output_vars"],
56
  )
57
  return pipeline_state
58
+
59
+
60
+ class UserInputWorkflowValidator:
61
+ def __init__(self, mode: Literal["tossup", "bonus"]):
62
+ self.required_input_vars = CONFIGS[mode]["required_input_vars"]
63
+ self.required_output_vars = CONFIGS[mode]["required_output_vars"]
64
+
65
+ def __call__(self, workflow: TossupWorkflow):
66
+ input_vars = set(workflow.inputs)
67
+ for req_var in self.required_input_vars:
68
+ if req_var not in input_vars:
69
+ default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
70
+ raise ValueError(
71
+ f"Missing required input variable: '{req_var}'. "
72
+ "\nDon't modify the 'inputs' field in the workflow. "
73
+ "Please set it back to:"
74
+ f"\n{default_str}"
75
+ )
76
+
77
+ output_vars = set(workflow.outputs)
78
+ for req_var in self.required_output_vars:
79
+ if req_var not in output_vars:
80
+ default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
81
+ raise ValueError(
82
+ f"Missing required output variable: '{req_var}'. "
83
+ "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
84
+ f"\nMake sure you have values set for all the outputs: {default_str}"
85
+ )
src/components/structs.py CHANGED
@@ -143,6 +143,11 @@ class PipelineState(BaseModel):
143
  update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
144
  return self.model_copy(update=update)
145
 
 
 
 
 
 
146
  def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
147
  """Get all variables from all steps."""
148
  available_variables = self.available_variables
@@ -170,3 +175,14 @@ class PipelineState(BaseModel):
170
 
171
  class TossupPipelineState(PipelineState):
172
  workflow: TossupWorkflow
 
 
 
 
 
 
 
 
 
 
 
 
143
  update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
144
  return self.model_copy(update=update)
145
 
146
+ def update_output_variable(self, target: str, produced_variable: str) -> "PipelineState":
147
+ """Update the output variables for a step."""
148
+ self.workflow.outputs[target] = produced_variable
149
+ return self
150
+
151
  def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
152
  """Get all variables from all steps."""
153
  available_variables = self.available_variables
 
175
 
176
  class TossupPipelineState(PipelineState):
177
  workflow: TossupWorkflow
178
+
179
+ def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "TossupPipelineState":
180
+ """Update a step in the pipeline."""
181
+ state = super().update_step(step, ui_state)
182
+ state.workflow = state.workflow.refresh_buzzer()
183
+ return state
184
+
185
+ def update_output_variable(self, target: str, produced_variable: str) -> "TossupPipelineState":
186
+ state = super().update_output_variable(target, produced_variable)
187
+ state.workflow = state.workflow.refresh_buzzer()
188
+ return state
src/workflows/validators.py CHANGED
@@ -13,7 +13,6 @@ SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "lis
13
  MAX_FIELD_NAME_LENGTH = 50
14
  MAX_DESCRIPTION_LENGTH = 200
15
  MAX_SYSTEM_PROMPT_LENGTH = 4000
16
- MIN_TEMPERATURE = 0.0
17
  MAX_TEMPERATURE = 10.0
18
 
19
 
@@ -40,7 +39,7 @@ class ValidationError:
40
  field_name: Optional[str] = None
41
 
42
 
43
- class WorkflowValidationError(Exception):
44
  """Base class for workflow validation errors"""
45
 
46
  def __init__(self, errors: list[ValidationError]):
@@ -77,9 +76,18 @@ def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
77
  class WorkflowValidator:
78
  """Validates workflows for correctness and consistency"""
79
 
80
- def __init__(self):
 
 
 
 
 
 
 
81
  self.errors: list[ValidationError] = []
82
  self.workflow: Optional[Workflow] = None
 
 
83
 
84
  def validate(self, workflow: Workflow) -> bool:
85
  """Main validation entry point"""
@@ -272,7 +280,7 @@ class WorkflowValidator:
272
  self.errors.append(
273
  ValidationError(
274
  ValidationErrorType.NAMING,
275
- f"Invalid step ID format: {step.id}. Must be a valid Python identifier.",
276
  step.id,
277
  )
278
  )
@@ -286,11 +294,11 @@ class WorkflowValidator:
286
  )
287
  return False
288
 
289
- if not MIN_TEMPERATURE <= step.temperature <= MAX_TEMPERATURE:
290
  self.errors.append(
291
  ValidationError(
292
  ValidationErrorType.RANGE,
293
- f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}",
294
  step.id,
295
  )
296
  )
@@ -304,11 +312,11 @@ class WorkflowValidator:
304
  )
305
  return False
306
 
307
- if len(step.system_prompt) > MAX_SYSTEM_PROMPT_LENGTH:
308
  self.errors.append(
309
  ValidationError(
310
  ValidationErrorType.LENGTH,
311
- f"System prompt exceeds maximum length of {MAX_SYSTEM_PROMPT_LENGTH} characters",
312
  step.id,
313
  )
314
  )
@@ -365,22 +373,22 @@ class WorkflowValidator:
365
  return False
366
 
367
  # Validate field name length
368
- if len(field.name) > MAX_FIELD_NAME_LENGTH:
369
  self.errors.append(
370
  ValidationError(
371
  ValidationErrorType.LENGTH,
372
- f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
373
  field_name=field.name,
374
  )
375
  )
376
  return False
377
 
378
  # Validate description length
379
- if len(field.description) > MAX_DESCRIPTION_LENGTH:
380
  self.errors.append(
381
  ValidationError(
382
  ValidationErrorType.LENGTH,
383
- f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
384
  field_name=field.name,
385
  )
386
  )
@@ -422,22 +430,22 @@ class WorkflowValidator:
422
  return False
423
 
424
  # Validate field name length
425
- if len(field.name) > MAX_FIELD_NAME_LENGTH:
426
  self.errors.append(
427
  ValidationError(
428
  ValidationErrorType.LENGTH,
429
- f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
430
  field_name=field.name,
431
  )
432
  )
433
  return False
434
 
435
  # Validate description length
436
- if len(field.description) > MAX_DESCRIPTION_LENGTH:
437
  self.errors.append(
438
  ValidationError(
439
  ValidationErrorType.LENGTH,
440
- f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
441
  field_name=field.name,
442
  )
443
  )
@@ -545,10 +553,6 @@ class WorkflowValidator:
545
 
546
  def _is_valid_identifier(self, name: str) -> bool:
547
  """Validates if a string is a valid Python identifier"""
548
- if not name:
549
- return False
550
- if keyword.iskeyword(name):
551
- return False
552
- if not name.strip(): # Check for whitespace-only strings
553
- return False
554
- return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
 
13
  MAX_FIELD_NAME_LENGTH = 50
14
  MAX_DESCRIPTION_LENGTH = 200
15
  MAX_SYSTEM_PROMPT_LENGTH = 4000
 
16
  MAX_TEMPERATURE = 10.0
17
 
18
 
 
39
  field_name: Optional[str] = None
40
 
41
 
42
+ class WorkflowValidationError(ValueError):
43
  """Base class for workflow validation errors"""
44
 
45
  def __init__(self, errors: list[ValidationError]):
 
76
  class WorkflowValidator:
77
  """Validates workflows for correctness and consistency"""
78
 
79
+ def __init__(
80
+ self,
81
+ min_temperature: float = 0,
82
+ max_temperature: float = MAX_TEMPERATURE,
83
+ max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
84
+ max_description_length: int = MAX_DESCRIPTION_LENGTH,
85
+ max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
86
+ ):
87
  self.errors: list[ValidationError] = []
88
  self.workflow: Optional[Workflow] = None
89
+ self.min_temperature = min_temperature
90
+ self.max_temperature = max_temperature
91
 
92
  def validate(self, workflow: Workflow) -> bool:
93
  """Main validation entry point"""
 
280
  self.errors.append(
281
  ValidationError(
282
  ValidationErrorType.NAMING,
283
+ f"Invalid step ID format: {step.id}. Must be a valid identifier.",
284
  step.id,
285
  )
286
  )
 
294
  )
295
  return False
296
 
297
+ if not self.min_temperature <= step.temperature <= self.max_temperature:
298
  self.errors.append(
299
  ValidationError(
300
  ValidationErrorType.RANGE,
301
+ f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
302
  step.id,
303
  )
304
  )
 
312
  )
313
  return False
314
 
315
+ if len(step.system_prompt) > self.max_system_prompt_length:
316
  self.errors.append(
317
  ValidationError(
318
  ValidationErrorType.LENGTH,
319
+ f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
320
  step.id,
321
  )
322
  )
 
373
  return False
374
 
375
  # Validate field name length
376
+ if len(field.name) > self.max_field_name_length:
377
  self.errors.append(
378
  ValidationError(
379
  ValidationErrorType.LENGTH,
380
+ f"Field name exceeds maximum length of {self.max_field_name_length} characters",
381
  field_name=field.name,
382
  )
383
  )
384
  return False
385
 
386
  # Validate description length
387
+ if len(field.description) > self.max_description_length:
388
  self.errors.append(
389
  ValidationError(
390
  ValidationErrorType.LENGTH,
391
+ f"Description exceeds maximum length of {self.max_description_length} characters",
392
  field_name=field.name,
393
  )
394
  )
 
430
  return False
431
 
432
  # Validate field name length
433
+ if len(field.name) > self.max_field_name_length:
434
  self.errors.append(
435
  ValidationError(
436
  ValidationErrorType.LENGTH,
437
+ f"Field name exceeds maximum length of {self.max_field_name_length} characters",
438
  field_name=field.name,
439
  )
440
  )
441
  return False
442
 
443
  # Validate description length
444
+ if len(field.description) > self.max_description_length:
445
  self.errors.append(
446
  ValidationError(
447
  ValidationErrorType.LENGTH,
448
+ f"Description exceeds maximum length of {self.max_description_length} characters",
449
  field_name=field.name,
450
  )
451
  )
 
553
 
554
  def _is_valid_identifier(self, name: str) -> bool:
555
  """Validates if a string is a valid Python identifier"""
556
+ if name and name.strip():
557
+ return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
558
+ return False