Maharshi Gor
committed on
Commit
·
849566b
1
Parent(s):
c1ae336
Add user input validation to pipeline interfaces and display errors on pipeline change.
Browse files- src/components/model_pipeline/model_pipeline.py +13 -6
- src/components/model_pipeline/state_manager.py +120 -42
- src/components/model_pipeline/tossup_pipeline.py +4 -3
- src/components/quizbowl/bonus.py +2 -0
- src/components/quizbowl/tossup.py +3 -1
- src/components/quizbowl/validation.py +30 -0
- src/components/structs.py +16 -0
- src/workflows/validators.py +27 -23
src/components/model_pipeline/model_pipeline.py
CHANGED
@@ -10,13 +10,14 @@ from components.model_pipeline.state_manager import (
|
|
10 |
PipelineState,
|
11 |
PipelineStateManager,
|
12 |
PipelineUIState,
|
|
|
13 |
TossupPipelineState,
|
14 |
TossupPipelineStateManager,
|
15 |
)
|
16 |
from components.model_step.model_step import ModelStepComponent
|
17 |
from components.utils import make_state
|
18 |
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
19 |
-
from workflows.validators import WorkflowValidator
|
20 |
|
21 |
from .state_manager import get_output_panel_state
|
22 |
|
@@ -33,6 +34,7 @@ class PipelineInterface:
|
|
33 |
ui_state: PipelineUIState | None = None,
|
34 |
model_options: list[str] = None,
|
35 |
config: dict = {},
|
|
|
36 |
):
|
37 |
self.app = app
|
38 |
self.model_options = model_options
|
@@ -50,10 +52,10 @@ class PipelineInterface:
|
|
50 |
|
51 |
if isinstance(workflow, TossupWorkflow):
|
52 |
pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
|
53 |
-
self.sm = TossupPipelineStateManager()
|
54 |
else:
|
55 |
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
56 |
-
self.sm = PipelineStateManager()
|
57 |
self.pipeline_state = make_state(pipeline_state.model_dump())
|
58 |
|
59 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
@@ -169,7 +171,11 @@ class PipelineInterface:
|
|
169 |
"""Validate the workflow."""
|
170 |
try:
|
171 |
state = self.sm.make_pipeline_state(state_dict)
|
172 |
-
WorkflowValidator(
|
|
|
|
|
|
|
|
|
173 |
except ValueError as e:
|
174 |
logger.exception(e)
|
175 |
state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
|
@@ -244,6 +250,7 @@ class PipelineInterface:
|
|
244 |
def _render_pipeline_preview(self):
|
245 |
export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
|
246 |
# components.append(export_btn)
|
|
|
247 |
|
248 |
# Add a code box to display the workflow JSON
|
249 |
# with gr.Column(elem_classes="workflow-json-container"):
|
@@ -262,7 +269,7 @@ class PipelineInterface:
|
|
262 |
self.config_output.blur(
|
263 |
fn=self.sm.update_workflow_from_code,
|
264 |
inputs=[self.config_output, self.pipeline_change],
|
265 |
-
outputs=[self.pipeline_state, self.pipeline_change],
|
266 |
)
|
267 |
|
268 |
# Connect the export button to show the workflow JSON
|
@@ -326,6 +333,6 @@ class PipelineInterface:
|
|
326 |
).success(
|
327 |
fn=self.sm.get_formatted_config,
|
328 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
329 |
-
outputs=[self.config_output],
|
330 |
js=js,
|
331 |
)
|
|
|
10 |
PipelineState,
|
11 |
PipelineStateManager,
|
12 |
PipelineUIState,
|
13 |
+
PipelineValidator,
|
14 |
TossupPipelineState,
|
15 |
TossupPipelineStateManager,
|
16 |
)
|
17 |
from components.model_step.model_step import ModelStepComponent
|
18 |
from components.utils import make_state
|
19 |
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
20 |
+
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
21 |
|
22 |
from .state_manager import get_output_panel_state
|
23 |
|
|
|
34 |
ui_state: PipelineUIState | None = None,
|
35 |
model_options: list[str] = None,
|
36 |
config: dict = {},
|
37 |
+
validator: PipelineValidator | None = None,
|
38 |
):
|
39 |
self.app = app
|
40 |
self.model_options = model_options
|
|
|
52 |
|
53 |
if isinstance(workflow, TossupWorkflow):
|
54 |
pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
|
55 |
+
self.sm = TossupPipelineStateManager(validator)
|
56 |
else:
|
57 |
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
58 |
+
self.sm = PipelineStateManager(validator)
|
59 |
self.pipeline_state = make_state(pipeline_state.model_dump())
|
60 |
|
61 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
|
|
171 |
"""Validate the workflow."""
|
172 |
try:
|
173 |
state = self.sm.make_pipeline_state(state_dict)
|
174 |
+
validator = WorkflowValidator(
|
175 |
+
max_temperature=self.config.get("max_temperature", 10),
|
176 |
+
)
|
177 |
+
if not validator.validate(state.workflow):
|
178 |
+
raise WorkflowValidationError(validator.errors)
|
179 |
except ValueError as e:
|
180 |
logger.exception(e)
|
181 |
state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
|
|
|
250 |
def _render_pipeline_preview(self):
|
251 |
export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
|
252 |
# components.append(export_btn)
|
253 |
+
self.error_display = gr.HTML(label="Error", elem_id="pipeline-preview-error-display", visible=False)
|
254 |
|
255 |
# Add a code box to display the workflow JSON
|
256 |
# with gr.Column(elem_classes="workflow-json-container"):
|
|
|
269 |
self.config_output.blur(
|
270 |
fn=self.sm.update_workflow_from_code,
|
271 |
inputs=[self.config_output, self.pipeline_change],
|
272 |
+
outputs=[self.pipeline_state, self.pipeline_change, self.error_display],
|
273 |
)
|
274 |
|
275 |
# Connect the export button to show the workflow JSON
|
|
|
333 |
).success(
|
334 |
fn=self.sm.get_formatted_config,
|
335 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
336 |
+
outputs=[self.config_output, self.error_display],
|
337 |
js=js,
|
338 |
)
|
src/components/model_pipeline/state_manager.py
CHANGED
@@ -1,7 +1,12 @@
|
|
|
|
1 |
import json
|
|
|
2 |
from typing import Literal
|
3 |
|
|
|
4 |
import yaml
|
|
|
|
|
5 |
|
6 |
from app_configs import UNSELECTED_VAR_NAME
|
7 |
from components import typed_dicts as td
|
@@ -22,24 +27,49 @@ def get_output_panel_state(workflow: Workflow) -> dict:
|
|
22 |
return state
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class PipelineStateManager:
|
26 |
"""Manages a pipeline of multiple steps."""
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
|
29 |
"""Make a state from a state dictionary."""
|
30 |
-
return
|
31 |
-
|
32 |
-
def get_formatted_config(self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml") -> str:
|
33 |
-
"""Get the full pipeline configuration."""
|
34 |
-
state = self.make_pipeline_state(state_dict)
|
35 |
-
config = state.workflow.model_dump(exclude_defaults=True)
|
36 |
-
if isinstance(state.workflow, TossupWorkflow):
|
37 |
-
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
38 |
-
config["buzzer"] = buzzer_config
|
39 |
-
if format == "yaml":
|
40 |
-
return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
|
41 |
-
else:
|
42 |
-
return json.dumps(config, indent=4, sort_keys=False)
|
43 |
|
44 |
def add_step(
|
45 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
@@ -102,7 +132,7 @@ class PipelineStateManager:
|
|
102 |
produced_variable = None
|
103 |
"""Update the output variables for a step."""
|
104 |
state = self.make_pipeline_state(state_dict)
|
105 |
-
state.
|
106 |
return state.model_dump()
|
107 |
|
108 |
def update_model_step_ui(
|
@@ -117,53 +147,101 @@ class PipelineStateManager:
|
|
117 |
"""Get all variables from all steps."""
|
118 |
return self.make_pipeline_state(state_dict)
|
119 |
|
120 |
-
def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
|
121 |
"""Parse a YAML workflow."""
|
122 |
workflow = yaml.safe_load(yaml_str)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
"""Update a workflow from a YAML string."""
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
|
131 |
class TossupPipelineStateManager(PipelineStateManager):
|
132 |
"""Manages a tossup pipeline state."""
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
return TossupPipelineState(**state_dict)
|
137 |
|
138 |
-
def
|
139 |
-
|
140 |
-
workflow = yaml.safe_load(yaml_str)
|
141 |
-
return TossupWorkflow(**workflow)
|
142 |
|
143 |
-
def update_workflow_from_code(
|
144 |
-
|
145 |
-
|
146 |
-
return
|
147 |
|
148 |
def update_model_step_state(
|
149 |
self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
|
150 |
) -> td.TossupPipelineStateDict:
|
151 |
-
|
152 |
-
state = self.make_pipeline_state(state_dict)
|
153 |
-
state = state.update_step(model_step, ui_state)
|
154 |
-
state.workflow = state.workflow.refresh_buzzer()
|
155 |
-
return state.model_dump()
|
156 |
|
157 |
def update_output_variables(
|
158 |
self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
|
159 |
) -> td.TossupPipelineStateDict:
|
160 |
-
|
161 |
-
produced_variable = None
|
162 |
-
"""Update the output variables for a step."""
|
163 |
-
state = self.make_pipeline_state(state_dict)
|
164 |
-
state.workflow.outputs[target] = produced_variable
|
165 |
-
state.workflow = state.workflow.refresh_buzzer()
|
166 |
-
return state.model_dump()
|
167 |
|
168 |
def update_buzzer(
|
169 |
self,
|
|
|
1 |
+
# %%
|
2 |
import json
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
from typing import Literal
|
5 |
|
6 |
+
import gradio as gr
|
7 |
import yaml
|
8 |
+
from loguru import logger
|
9 |
+
from pydantic import BaseModel, ValidationError
|
10 |
|
11 |
from app_configs import UNSELECTED_VAR_NAME
|
12 |
from components import typed_dicts as td
|
|
|
27 |
return state
|
28 |
|
29 |
|
30 |
+
# Cache of generated strict subclasses, keyed by the original model class, so the
# subclass is built once instead of on every validation call.
_STRICT_MODEL_CACHE: dict[type, type] = {}


def strict_model_validate(model_cls: type[BaseModel], data: dict):
    """Validate ``data`` against ``model_cls`` with ``extra="forbid"`` applied.

    A subclass named ``Strict<ModelName>`` is created dynamically whose
    ``model_config`` is the original config plus ``"extra": "forbid"``; the
    data is then validated through that subclass.

    Args:
        model_cls: The model class to derive the strict subclass from.
        data: The raw data to validate.

    Returns:
        Whatever ``model_validate`` of the strict subclass returns (an instance
        of the ``Strict<ModelName>`` subclass).

    Note:
        Only the top-level class receives the overridden config here; nested
        model types are used as declared (presumably they keep their own
        ``extra`` setting -- confirm against pydantic's config inheritance).
    """
    strict_class = _STRICT_MODEL_CACHE.get(model_cls)
    if strict_class is None:
        # Dynamically create a subclass with extra='forbid'.
        strict_class = type(
            f"Strict{model_cls.__name__}",
            (model_cls,),
            {"model_config": {**getattr(model_cls, "model_config", {}), "extra": "forbid"}},
        )
        _STRICT_MODEL_CACHE[model_cls] = strict_class

    return strict_class.model_validate(data)
|
42 |
+
|
43 |
+
|
44 |
+
class PipelineValidator(ABC):
|
45 |
+
"""Abstract base class for pipeline validators."""
|
46 |
+
|
47 |
+
@abstractmethod
|
48 |
+
def __call__(self, workflow: Workflow):
|
49 |
+
"""
|
50 |
+
Validate the workflow.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
workflow: The workflow to validate.
|
54 |
+
|
55 |
+
Raises:
|
56 |
+
ValueError: If the workflow is invalid.
|
57 |
+
"""
|
58 |
+
pass
|
59 |
+
|
60 |
+
|
61 |
class PipelineStateManager:
|
62 |
"""Manages a pipeline of multiple steps."""
|
63 |
|
64 |
+
pipeline_state_cls = PipelineState
|
65 |
+
workflow_cls = Workflow
|
66 |
+
|
67 |
+
def __init__(self, validator: PipelineValidator | None = None):
|
68 |
+
self.validator = validator
|
69 |
+
|
70 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
|
71 |
"""Make a state from a state dictionary."""
|
72 |
+
return self.pipeline_state_cls(**state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
def add_step(
|
75 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
|
|
132 |
produced_variable = None
|
133 |
"""Update the output variables for a step."""
|
134 |
state = self.make_pipeline_state(state_dict)
|
135 |
+
state = state.update_output_variable(target, produced_variable)
|
136 |
return state.model_dump()
|
137 |
|
138 |
def update_model_step_ui(
|
|
|
147 |
"""Get all variables from all steps."""
|
148 |
return self.make_pipeline_state(state_dict)
|
149 |
|
150 |
+
def parse_yaml_workflow(self, yaml_str: str, strict: bool = True) -> Workflow:
|
151 |
"""Parse a YAML workflow."""
|
152 |
workflow = yaml.safe_load(yaml_str)
|
153 |
+
try:
|
154 |
+
if strict:
|
155 |
+
return strict_model_validate(self.workflow_cls, workflow)
|
156 |
+
else:
|
157 |
+
return self.workflow_cls.model_validate(workflow)
|
158 |
+
except ValidationError as e:
|
159 |
+
new_exception = ValidationError.from_exception_data(
|
160 |
+
e.title.removeprefix("Strict"), e.errors(), input_type="json"
|
161 |
+
)
|
162 |
+
raise new_exception from e
|
163 |
+
|
164 |
+
def _handle_pipeline_parsing_error(self, e: Exception) -> str:
    """Format a pipeline parsing/validation exception as a styled HTML snippet.

    The exception is logged, classified into one of four categories (YAML
    syntax error, pydantic ``ValidationError``, generic ``ValueError``, or
    anything else), and rendered into an HTML block containing the error
    type, the error message, and a help link.

    Args:
        e: The exception raised while parsing or validating the pipeline.

    Returns:
        An HTML string describing the error.
    """
    # Imported locally so this method does not depend on the module's import block.
    import html

    error_template = """
    <div class="md" style='color: #FF0000; background-color: #FFEEEE; padding: 10px; border-radius: 5px; border-left: 4px solid #FF0000;'>
        <strong style='color: #FF0000;'>{error_type}:</strong> <br>
        <div class="code-wrap">
            <pre><code>{error_message}</code></pre>
        </div>
        {help_text}
    </div>
    """
    logger.exception(e)
    # ValidationError is tested before ValueError so that it is reported as a
    # parsing error rather than falling into the more general branch.
    if isinstance(e, yaml.YAMLError):
        error_type = "Invalid YAML Error"
        help_text = "Refer to the <a href='https://spacelift.io/blog/yaml#basic-yaml-syntax' target='_blank'>YAML schema</a> for correct formatting."
    elif isinstance(e, ValidationError):
        error_type = "Pipeline Parsing Error"
        help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
    elif isinstance(e, ValueError):
        error_type = "Pipeline Validation Error"
        help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
    else:
        error_type = "Unexpected Error"
        help_text = "Please report this issue to us at <a href='https://github.com/maharshi95/QANTA25/issues' target='_blank'>GitHub Issues</a>."

    # The message can echo fragments of the user-edited YAML; escape it so that
    # any markup in it is displayed as text instead of being rendered as HTML.
    safe_message = html.escape(str(e))
    return error_template.format(error_type=error_type, error_message=safe_message, help_text=help_text)
|
190 |
|
191 |
+
def get_formatted_config(
|
192 |
+
self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
|
193 |
+
) -> tuple[str, dict]:
|
194 |
+
"""Get the full pipeline configuration."""
|
195 |
+
try:
|
196 |
+
state = self.make_pipeline_state(state_dict)
|
197 |
+
config = state.workflow.model_dump(exclude_defaults=True)
|
198 |
+
if isinstance(state.workflow, TossupWorkflow):
|
199 |
+
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
200 |
+
config["buzzer"] = buzzer_config
|
201 |
+
if format == "yaml":
|
202 |
+
config_str = yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
|
203 |
+
else:
|
204 |
+
config_str = json.dumps(config, indent=4, sort_keys=False)
|
205 |
+
return config_str, gr.update(visible=False)
|
206 |
+
except Exception as e:
|
207 |
+
error_message = self._handle_pipeline_parsing_error(e)
|
208 |
+
return gr.skip(), gr.update(value=error_message, visible=True)
|
209 |
+
|
210 |
+
def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool, dict]:
    """Update a workflow from a YAML string.

    The YAML is parsed strictly into ``self.workflow_cls``, passed to the
    optional ``self.validator``, and converted into a pipeline state.

    Args:
        yaml_str: The YAML text of the workflow.
        change_state: Current value of the pipeline-change flag; it is
            flipped on success.

    Returns:
        A 3-tuple of (pipeline state dict, toggled change flag, update that
        hides the error display). If anything raises, the first two items are
        ``gr.skip()`` and the third is an update that shows the formatted
        error message.
    """
    try:
        workflow = self.parse_yaml_workflow(yaml_str, strict=True)
        # Explicit guard instead of `validator and validator(...)`: the call is
        # made for its side effect (raising on invalid input), not its value.
        if self.validator is not None:
            self.validator(workflow)
        state = self.pipeline_state_cls.from_workflow(workflow)
        return state.model_dump(), not change_state, gr.update(visible=False)
    except Exception as e:
        # UI boundary: every failure is logged and rendered by the error formatter.
        error_message = self._handle_pipeline_parsing_error(e)
        return gr.skip(), gr.skip(), gr.update(value=error_message, visible=True)
|
220 |
|
221 |
|
222 |
class TossupPipelineStateManager(PipelineStateManager):
|
223 |
"""Manages a tossup pipeline state."""
|
224 |
|
225 |
+
pipeline_state_cls = TossupPipelineState
|
226 |
+
workflow_cls = TossupWorkflow
|
|
|
227 |
|
228 |
+
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
|
229 |
+
return super().make_pipeline_state(state_dict)
|
|
|
|
|
230 |
|
231 |
+
def update_workflow_from_code(
|
232 |
+
self, yaml_str: str, change_state: bool
|
233 |
+
) -> tuple[td.TossupPipelineStateDict, bool, dict]:
|
234 |
+
return super().update_workflow_from_code(yaml_str, change_state)
|
235 |
|
236 |
def update_model_step_state(
|
237 |
self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
|
238 |
) -> td.TossupPipelineStateDict:
|
239 |
+
return super().update_model_step_state(state_dict, model_step, ui_state)
|
|
|
|
|
|
|
|
|
240 |
|
241 |
def update_output_variables(
|
242 |
self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
|
243 |
) -> td.TossupPipelineStateDict:
|
244 |
+
return super().update_output_variables(state_dict, target, produced_variable)
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
def update_buzzer(
|
247 |
self,
|
src/components/model_pipeline/tossup_pipeline.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
import numpy as np
|
3 |
from loguru import logger
|
4 |
|
5 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
@@ -9,7 +8,8 @@ from components.typed_dicts import TossupPipelineStateDict
|
|
9 |
from display.formatting import tiny_styled_warning
|
10 |
from workflows.structs import Buzzer, TossupWorkflow
|
11 |
|
12 |
-
from .model_pipeline import PipelineInterface
|
|
|
13 |
|
14 |
|
15 |
def toggleable_slider(
|
@@ -40,8 +40,9 @@ class TossupPipelineInterface(PipelineInterface):
|
|
40 |
ui_state: PipelineUIState | None = None,
|
41 |
model_options: list[str] = None,
|
42 |
config: dict = {},
|
|
|
43 |
):
|
44 |
-
super().__init__(app, workflow, ui_state, model_options, config)
|
45 |
|
46 |
self.buzzer_state = gr.State(workflow.buzzer.model_dump())
|
47 |
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from loguru import logger
|
3 |
|
4 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
|
|
8 |
from display.formatting import tiny_styled_warning
|
9 |
from workflows.structs import Buzzer, TossupWorkflow
|
10 |
|
11 |
+
from .model_pipeline import PipelineInterface
|
12 |
+
from .state_manager import PipelineUIState, PipelineValidator
|
13 |
|
14 |
|
15 |
def toggleable_slider(
|
|
|
40 |
ui_state: PipelineUIState | None = None,
|
41 |
model_options: list[str] = None,
|
42 |
config: dict = {},
|
43 |
+
validator: PipelineValidator | None = None,
|
44 |
):
|
45 |
+
super().__init__(app, workflow, ui_state, model_options, config, validator)
|
46 |
|
47 |
self.buzzer_state = gr.State(workflow.buzzer.model_dump())
|
48 |
|
src/components/quizbowl/bonus.py
CHANGED
@@ -19,6 +19,7 @@ from workflows.qb_agents import QuizBowlBonusAgent
|
|
19 |
from . import populate, validation
|
20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
21 |
from .utils import evaluate_prediction
|
|
|
22 |
|
23 |
|
24 |
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
|
@@ -105,6 +106,7 @@ class BonusInterface:
|
|
105 |
ui_state=pipeline_state.ui_state,
|
106 |
model_options=list(self.model_options.keys()),
|
107 |
config=self.defaults,
|
|
|
108 |
)
|
109 |
|
110 |
def _render_qb_interface(self):
|
|
|
19 |
from . import populate, validation
|
20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
21 |
from .utils import evaluate_prediction
|
22 |
+
from .validation import UserInputWorkflowValidator
|
23 |
|
24 |
|
25 |
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
|
|
|
106 |
ui_state=pipeline_state.ui_state,
|
107 |
model_options=list(self.model_options.keys()),
|
108 |
config=self.defaults,
|
109 |
+
validator=UserInputWorkflowValidator("bonus"),
|
110 |
)
|
111 |
|
112 |
def _render_qb_interface(self):
|
src/components/quizbowl/tossup.py
CHANGED
@@ -25,6 +25,7 @@ from .plotting import (
|
|
25 |
prepare_tossup_results_df,
|
26 |
)
|
27 |
from .utils import evaluate_prediction
|
|
|
28 |
|
29 |
# TODO: Error handling on run tossup and evaluate tossup and show correct messages
|
30 |
# TODO: ^^ Same for Bonus
|
@@ -135,7 +136,7 @@ class TossupInterface:
|
|
135 |
self.output_state = gr.State(value={})
|
136 |
self.render()
|
137 |
|
138 |
-
# ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE
|
139 |
|
140 |
def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
|
141 |
logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
|
@@ -165,6 +166,7 @@ class TossupInterface:
|
|
165 |
ui_state=pipeline_state.ui_state,
|
166 |
model_options=list(self.model_options.keys()),
|
167 |
config=self.defaults,
|
|
|
168 |
)
|
169 |
|
170 |
def _render_qb_interface(self):
|
|
|
25 |
prepare_tossup_results_df,
|
26 |
)
|
27 |
from .utils import evaluate_prediction
|
28 |
+
from .validation import UserInputWorkflowValidator
|
29 |
|
30 |
# TODO: Error handling on run tossup and evaluate tossup and show correct messages
|
31 |
# TODO: ^^ Same for Bonus
|
|
|
136 |
self.output_state = gr.State(value={})
|
137 |
self.render()
|
138 |
|
139 |
+
# ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE ------------------------------------
|
140 |
|
141 |
def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
|
142 |
logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
|
|
|
166 |
ui_state=pipeline_state.ui_state,
|
167 |
model_options=list(self.model_options.keys()),
|
168 |
config=self.defaults,
|
169 |
+
validator=UserInputWorkflowValidator("tossup"),
|
170 |
)
|
171 |
|
172 |
def _render_qb_interface(self):
|
src/components/quizbowl/validation.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from app_configs import CONFIGS
|
2 |
from components.structs import PipelineState, TossupPipelineState
|
3 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
@@ -53,3 +55,31 @@ def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict):
|
|
53 |
CONFIGS["bonus"]["required_output_vars"],
|
54 |
)
|
55 |
return pipeline_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
|
3 |
from app_configs import CONFIGS
|
4 |
from components.structs import PipelineState, TossupPipelineState
|
5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
|
|
55 |
CONFIGS["bonus"]["required_output_vars"],
|
56 |
)
|
57 |
return pipeline_state
|
58 |
+
|
59 |
+
|
60 |
+
class UserInputWorkflowValidator:
|
61 |
+
def __init__(self, mode: Literal["tossup", "bonus"]):
|
62 |
+
self.required_input_vars = CONFIGS[mode]["required_input_vars"]
|
63 |
+
self.required_output_vars = CONFIGS[mode]["required_output_vars"]
|
64 |
+
|
65 |
+
def __call__(self, workflow: TossupWorkflow):
|
66 |
+
input_vars = set(workflow.inputs)
|
67 |
+
for req_var in self.required_input_vars:
|
68 |
+
if req_var not in input_vars:
|
69 |
+
default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
|
70 |
+
raise ValueError(
|
71 |
+
f"Missing required input variable: '{req_var}'. "
|
72 |
+
"\nDon't modify the 'inputs' field in the workflow. "
|
73 |
+
"Please set it back to:"
|
74 |
+
f"\n{default_str}"
|
75 |
+
)
|
76 |
+
|
77 |
+
output_vars = set(workflow.outputs)
|
78 |
+
for req_var in self.required_output_vars:
|
79 |
+
if req_var not in output_vars:
|
80 |
+
default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
|
81 |
+
raise ValueError(
|
82 |
+
f"Missing required output variable: '{req_var}'. "
|
83 |
+
"\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
|
84 |
+
f"\nMake sure you have values set for all the outputs: {default_str}"
|
85 |
+
)
|
src/components/structs.py
CHANGED
@@ -143,6 +143,11 @@ class PipelineState(BaseModel):
|
|
143 |
update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
|
144 |
return self.model_copy(update=update)
|
145 |
|
|
|
|
|
|
|
|
|
|
|
146 |
def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
|
147 |
"""Get all variables from all steps."""
|
148 |
available_variables = self.available_variables
|
@@ -170,3 +175,14 @@ class PipelineState(BaseModel):
|
|
170 |
|
171 |
class TossupPipelineState(PipelineState):
|
172 |
workflow: TossupWorkflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
|
144 |
return self.model_copy(update=update)
|
145 |
|
146 |
+
def update_output_variable(self, target: str, produced_variable: str) -> "PipelineState":
|
147 |
+
"""Update the output variables for a step."""
|
148 |
+
self.workflow.outputs[target] = produced_variable
|
149 |
+
return self
|
150 |
+
|
151 |
def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
|
152 |
"""Get all variables from all steps."""
|
153 |
available_variables = self.available_variables
|
|
|
175 |
|
176 |
class TossupPipelineState(PipelineState):
|
177 |
workflow: TossupWorkflow
|
178 |
+
|
179 |
+
def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "TossupPipelineState":
|
180 |
+
"""Update a step in the pipeline."""
|
181 |
+
state = super().update_step(step, ui_state)
|
182 |
+
state.workflow = state.workflow.refresh_buzzer()
|
183 |
+
return state
|
184 |
+
|
185 |
+
def update_output_variable(self, target: str, produced_variable: str) -> "TossupPipelineState":
|
186 |
+
state = super().update_output_variable(target, produced_variable)
|
187 |
+
state.workflow = state.workflow.refresh_buzzer()
|
188 |
+
return state
|
src/workflows/validators.py
CHANGED
@@ -13,7 +13,6 @@ SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "lis
|
|
13 |
MAX_FIELD_NAME_LENGTH = 50
|
14 |
MAX_DESCRIPTION_LENGTH = 200
|
15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
16 |
-
MIN_TEMPERATURE = 0.0
|
17 |
MAX_TEMPERATURE = 10.0
|
18 |
|
19 |
|
@@ -40,7 +39,7 @@ class ValidationError:
|
|
40 |
field_name: Optional[str] = None
|
41 |
|
42 |
|
43 |
-
class WorkflowValidationError(
|
44 |
"""Base class for workflow validation errors"""
|
45 |
|
46 |
def __init__(self, errors: list[ValidationError]):
|
@@ -77,9 +76,18 @@ def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
|
|
77 |
class WorkflowValidator:
|
78 |
"""Validates workflows for correctness and consistency"""
|
79 |
|
80 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
self.errors: list[ValidationError] = []
|
82 |
self.workflow: Optional[Workflow] = None
|
|
|
|
|
83 |
|
84 |
def validate(self, workflow: Workflow) -> bool:
|
85 |
"""Main validation entry point"""
|
@@ -272,7 +280,7 @@ class WorkflowValidator:
|
|
272 |
self.errors.append(
|
273 |
ValidationError(
|
274 |
ValidationErrorType.NAMING,
|
275 |
-
f"Invalid step ID format: {step.id}. Must be a valid
|
276 |
step.id,
|
277 |
)
|
278 |
)
|
@@ -286,11 +294,11 @@ class WorkflowValidator:
|
|
286 |
)
|
287 |
return False
|
288 |
|
289 |
-
if not
|
290 |
self.errors.append(
|
291 |
ValidationError(
|
292 |
ValidationErrorType.RANGE,
|
293 |
-
f"Temperature must be between {
|
294 |
step.id,
|
295 |
)
|
296 |
)
|
@@ -304,11 +312,11 @@ class WorkflowValidator:
|
|
304 |
)
|
305 |
return False
|
306 |
|
307 |
-
if len(step.system_prompt) >
|
308 |
self.errors.append(
|
309 |
ValidationError(
|
310 |
ValidationErrorType.LENGTH,
|
311 |
-
f"System prompt exceeds maximum length of {
|
312 |
step.id,
|
313 |
)
|
314 |
)
|
@@ -365,22 +373,22 @@ class WorkflowValidator:
|
|
365 |
return False
|
366 |
|
367 |
# Validate field name length
|
368 |
-
if len(field.name) >
|
369 |
self.errors.append(
|
370 |
ValidationError(
|
371 |
ValidationErrorType.LENGTH,
|
372 |
-
f"Field name exceeds maximum length of {
|
373 |
field_name=field.name,
|
374 |
)
|
375 |
)
|
376 |
return False
|
377 |
|
378 |
# Validate description length
|
379 |
-
if len(field.description) >
|
380 |
self.errors.append(
|
381 |
ValidationError(
|
382 |
ValidationErrorType.LENGTH,
|
383 |
-
f"Description exceeds maximum length of {
|
384 |
field_name=field.name,
|
385 |
)
|
386 |
)
|
@@ -422,22 +430,22 @@ class WorkflowValidator:
|
|
422 |
return False
|
423 |
|
424 |
# Validate field name length
|
425 |
-
if len(field.name) >
|
426 |
self.errors.append(
|
427 |
ValidationError(
|
428 |
ValidationErrorType.LENGTH,
|
429 |
-
f"Field name exceeds maximum length of {
|
430 |
field_name=field.name,
|
431 |
)
|
432 |
)
|
433 |
return False
|
434 |
|
435 |
# Validate description length
|
436 |
-
if len(field.description) >
|
437 |
self.errors.append(
|
438 |
ValidationError(
|
439 |
ValidationErrorType.LENGTH,
|
440 |
-
f"Description exceeds maximum length of {
|
441 |
field_name=field.name,
|
442 |
)
|
443 |
)
|
@@ -545,10 +553,6 @@ class WorkflowValidator:
|
|
545 |
|
546 |
def _is_valid_identifier(self, name: str) -> bool:
|
547 |
"""Validates if a string is a valid Python identifier"""
|
548 |
-
if
|
549 |
-
return
|
550 |
-
|
551 |
-
return False
|
552 |
-
if not name.strip(): # Check for whitespace-only strings
|
553 |
-
return False
|
554 |
-
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
|
|
13 |
MAX_FIELD_NAME_LENGTH = 50
|
14 |
MAX_DESCRIPTION_LENGTH = 200
|
15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
|
|
16 |
MAX_TEMPERATURE = 10.0
|
17 |
|
18 |
|
|
|
39 |
field_name: Optional[str] = None
|
40 |
|
41 |
|
42 |
+
class WorkflowValidationError(ValueError):
|
43 |
"""Base class for workflow validation errors"""
|
44 |
|
45 |
def __init__(self, errors: list[ValidationError]):
|
|
|
76 |
class WorkflowValidator:
|
77 |
"""Validates workflows for correctness and consistency"""
|
78 |
|
79 |
+
def __init__(
    self,
    min_temperature: float = 0,
    max_temperature: float = MAX_TEMPERATURE,
    max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
    max_description_length: int = MAX_DESCRIPTION_LENGTH,
    max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
):
    """Create a validator with configurable limits.

    Args:
        min_temperature: Lowest accepted step temperature.
        max_temperature: Highest accepted step temperature.
        max_field_name_length: Maximum accepted length of a field name.
        max_description_length: Maximum accepted length of a field description.
        max_system_prompt_length: Maximum accepted length of a step's system prompt.
    """
    self.errors: list[ValidationError] = []
    self.workflow: Optional[Workflow] = None
    self.min_temperature = min_temperature
    self.max_temperature = max_temperature
    # The length limits are read as instance attributes by the step and field
    # checks, so they must be stored here alongside the temperature bounds.
    self.max_field_name_length = max_field_name_length
    self.max_description_length = max_description_length
    self.max_system_prompt_length = max_system_prompt_length
|
91 |
|
92 |
def validate(self, workflow: Workflow) -> bool:
|
93 |
"""Main validation entry point"""
|
|
|
280 |
self.errors.append(
|
281 |
ValidationError(
|
282 |
ValidationErrorType.NAMING,
|
283 |
+
f"Invalid step ID format: {step.id}. Must be a valid identifier.",
|
284 |
step.id,
|
285 |
)
|
286 |
)
|
|
|
294 |
)
|
295 |
return False
|
296 |
|
297 |
+
if not self.min_temperature <= step.temperature <= self.max_temperature:
|
298 |
self.errors.append(
|
299 |
ValidationError(
|
300 |
ValidationErrorType.RANGE,
|
301 |
+
f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
|
302 |
step.id,
|
303 |
)
|
304 |
)
|
|
|
312 |
)
|
313 |
return False
|
314 |
|
315 |
+
if len(step.system_prompt) > self.max_system_prompt_length:
|
316 |
self.errors.append(
|
317 |
ValidationError(
|
318 |
ValidationErrorType.LENGTH,
|
319 |
+
f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
|
320 |
step.id,
|
321 |
)
|
322 |
)
|
|
|
373 |
return False
|
374 |
|
375 |
# Validate field name length
|
376 |
+
if len(field.name) > self.max_field_name_length:
|
377 |
self.errors.append(
|
378 |
ValidationError(
|
379 |
ValidationErrorType.LENGTH,
|
380 |
+
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
381 |
field_name=field.name,
|
382 |
)
|
383 |
)
|
384 |
return False
|
385 |
|
386 |
# Validate description length
|
387 |
+
if len(field.description) > self.max_description_length:
|
388 |
self.errors.append(
|
389 |
ValidationError(
|
390 |
ValidationErrorType.LENGTH,
|
391 |
+
f"Description exceeds maximum length of {self.max_description_length} characters",
|
392 |
field_name=field.name,
|
393 |
)
|
394 |
)
|
|
|
430 |
return False
|
431 |
|
432 |
# Validate field name length
|
433 |
+
if len(field.name) > self.max_field_name_length:
|
434 |
self.errors.append(
|
435 |
ValidationError(
|
436 |
ValidationErrorType.LENGTH,
|
437 |
+
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
438 |
field_name=field.name,
|
439 |
)
|
440 |
)
|
441 |
return False
|
442 |
|
443 |
# Validate description length
|
444 |
+
if len(field.description) > self.max_description_length:
|
445 |
self.errors.append(
|
446 |
ValidationError(
|
447 |
ValidationErrorType.LENGTH,
|
448 |
+
f"Description exceeds maximum length of {self.max_description_length} characters",
|
449 |
field_name=field.name,
|
450 |
)
|
451 |
)
|
|
|
553 |
|
554 |
def _is_valid_identifier(self, name: str) -> bool:
|
555 |
"""Validates if a string is a valid Python identifier"""
|
556 |
+
if name and name.strip():
|
557 |
+
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
558 |
+
return False
|
|
|
|
|
|
|
|