Maharshi Gor
commited on
Commit
·
9756440
1
Parent(s):
b6a16f6
Squash merge dictify-states into main
Browse files- app.py +7 -3
- src/components/model_pipeline/model_pipeline.py +89 -87
- src/components/model_pipeline/state_manager.py +108 -163
- src/components/model_pipeline/tossup_pipeline.py +25 -44
- src/components/model_step/model_step.py +4 -3
- src/components/quizbowl/bonus.py +28 -17
- src/components/quizbowl/populate.py +6 -3
- src/components/quizbowl/tossup.py +32 -31
- src/components/structs.py +191 -0
- src/components/typed_dicts.py +64 -0
- src/display/guide.py +4 -0
- src/submission/structs.py +1 -0
- src/workflows/structs.py +42 -8
app.py
CHANGED
@@ -11,7 +11,7 @@ from components.quizbowl.bonus import BonusInterface
|
|
11 |
from components.quizbowl.tossup import TossupInterface
|
12 |
from display.css_html_js import fonts_header, js_head, leaderboard_css
|
13 |
from display.custom_css import css_bonus, css_pipeline, css_tossup
|
14 |
-
from display.guide import GUIDE_MARKDOWN
|
15 |
from display.utils import AutoEvalColumn, fields
|
16 |
|
17 |
# Constants
|
@@ -91,8 +91,6 @@ if __name__ == "__main__":
|
|
91 |
theme=THEME,
|
92 |
title="Quizbowl Bot",
|
93 |
) as demo:
|
94 |
-
with gr.Sidebar(width=400):
|
95 |
-
gr.Markdown(GUIDE_MARKDOWN)
|
96 |
with gr.Row():
|
97 |
gr.Markdown("## Welcome to Quizbowl Bot! This is a tool for creating and testing quizbowl agents.")
|
98 |
with gr.Tabs() as gtab:
|
@@ -121,5 +119,11 @@ if __name__ == "__main__":
|
|
121 |
visible=True,
|
122 |
)
|
123 |
refresh_btn.click(fn=fetch_leaderboard_df, inputs=[], outputs=leaderboard_table)
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
demo.queue(default_concurrency_limit=40).launch()
|
|
|
11 |
from components.quizbowl.tossup import TossupInterface
|
12 |
from display.css_html_js import fonts_header, js_head, leaderboard_css
|
13 |
from display.custom_css import css_bonus, css_pipeline, css_tossup
|
14 |
+
from display.guide import BUILDING_MARKDOWN, GUIDE_MARKDOWN, QUICKSTART_MARKDOWN
|
15 |
from display.utils import AutoEvalColumn, fields
|
16 |
|
17 |
# Constants
|
|
|
91 |
theme=THEME,
|
92 |
title="Quizbowl Bot",
|
93 |
) as demo:
|
|
|
|
|
94 |
with gr.Row():
|
95 |
gr.Markdown("## Welcome to Quizbowl Bot! This is a tool for creating and testing quizbowl agents.")
|
96 |
with gr.Tabs() as gtab:
|
|
|
119 |
visible=True,
|
120 |
)
|
121 |
refresh_btn.click(fn=fetch_leaderboard_df, inputs=[], outputs=leaderboard_table)
|
122 |
+
with gr.Tab("❓ Help", id="help"):
|
123 |
+
with gr.Row():
|
124 |
+
with gr.Column():
|
125 |
+
gr.Markdown(QUICKSTART_MARKDOWN)
|
126 |
+
with gr.Column():
|
127 |
+
gr.Markdown(BUILDING_MARKDOWN)
|
128 |
|
129 |
demo.queue(default_concurrency_limit=40).launch()
|
src/components/model_pipeline/model_pipeline.py
CHANGED
@@ -4,79 +4,74 @@ from loguru import logger
|
|
4 |
|
5 |
from app_configs import UNSELECTED_VAR_NAME
|
6 |
from components import commons
|
|
|
7 |
from components.model_pipeline.state_manager import (
|
8 |
ModelStepUIState,
|
9 |
PipelineState,
|
10 |
PipelineStateManager,
|
11 |
PipelineUIState,
|
|
|
|
|
12 |
)
|
13 |
from components.model_step.model_step import ModelStepComponent
|
14 |
from components.utils import make_state
|
15 |
-
from workflows.structs import ModelStep, Workflow
|
16 |
from workflows.validators import WorkflowValidator
|
17 |
|
18 |
|
19 |
-
def validate_simple_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
|
20 |
-
"""Validate the workflow."""
|
21 |
-
step = next(iter(workflow.steps.values()))
|
22 |
-
if not step.output_fields:
|
23 |
-
raise ValueError("No output fields found in the workflow")
|
24 |
-
output_field_names = {output.name for output in step.output_fields}
|
25 |
-
if not set(required_output_variables) <= output_field_names:
|
26 |
-
missing_vars = required_output_variables - output_field_names
|
27 |
-
raise ValueError(f"Missing required output variables: {missing_vars}")
|
28 |
-
return workflow
|
29 |
-
|
30 |
-
|
31 |
-
def validate_complex_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
|
32 |
-
"""Validate the workflow."""
|
33 |
-
print("Validating complex workflow.")
|
34 |
-
return workflow
|
35 |
-
step = next(iter(workflow.steps.values()))
|
36 |
-
if not step.output_fields:
|
37 |
-
raise ValueError("No output fields found in the workflow")
|
38 |
-
output_field_names = {output.name for output in step.output_fields}
|
39 |
-
if not output_field_names <= set(required_output_variables):
|
40 |
-
missing_vars = output_field_names - set(required_output_variables)
|
41 |
-
raise ValueError(f"Missing required output variables: {missing_vars}")
|
42 |
-
return workflow
|
43 |
-
|
44 |
-
|
45 |
-
def parse_yaml_workflow(yaml_str: str) -> Workflow:
|
46 |
-
"""Parse a YAML workflow."""
|
47 |
-
workflow = yaml.safe_load(yaml_str)
|
48 |
-
return Workflow(**workflow)
|
49 |
-
|
50 |
-
|
51 |
-
def update_workflow_from_code(yaml_str: str, ui_state: PipelineUIState) -> PipelineState:
|
52 |
-
"""Update a workflow from a YAML string."""
|
53 |
-
workflow = parse_yaml_workflow(yaml_str)
|
54 |
-
ui_state = PipelineUIState.from_workflow(workflow)
|
55 |
-
return PipelineState(workflow=workflow, ui_state=ui_state)
|
56 |
-
|
57 |
-
|
58 |
class PipelineInterface:
|
59 |
"""UI for the pipeline."""
|
60 |
|
61 |
def __init__(
|
62 |
self,
|
|
|
63 |
workflow: Workflow,
|
64 |
ui_state: PipelineUIState | None = None,
|
65 |
model_options: list[str] = None,
|
66 |
simple: bool = False,
|
67 |
-
show_pipeline_selector: bool = False,
|
68 |
):
|
|
|
69 |
self.model_options = model_options
|
70 |
self.simple = simple
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
self.
|
75 |
-
self.pipeline_state = make_state(PipelineState(workflow=workflow, ui_state=ui_state))
|
76 |
self.variables_state = make_state(workflow.get_available_variables())
|
77 |
-
self.model_selection_state = make_state(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
self.input_variables = workflow.inputs
|
81 |
self.required_output_variables = list(workflow.outputs.keys())
|
82 |
|
@@ -108,13 +103,13 @@ class PipelineInterface:
|
|
108 |
step_interface.on_model_step_change(
|
109 |
self.sm.update_model_step_state,
|
110 |
inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
|
111 |
-
outputs=[self.pipeline_state
|
112 |
)
|
113 |
|
114 |
step_interface.on_ui_change(
|
115 |
self.sm.update_model_step_ui,
|
116 |
inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
|
117 |
-
outputs=[self.pipeline_state
|
118 |
)
|
119 |
|
120 |
if self.simple:
|
@@ -137,12 +132,20 @@ class PipelineInterface:
|
|
137 |
def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
|
138 |
up_button, down_button, remove_button = buttons
|
139 |
position = gr.State(position)
|
140 |
-
up_button.click(
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
remove_button.click(
|
143 |
self.sm.remove_step,
|
144 |
-
inputs=[self.pipeline_state, position],
|
145 |
-
outputs=[self.pipeline_state, self.
|
146 |
)
|
147 |
|
148 |
def _render_add_step_button(self, position: int):
|
@@ -153,8 +156,8 @@ class PipelineInterface:
|
|
153 |
add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
|
154 |
add_step_btn.click(
|
155 |
self.sm.add_step,
|
156 |
-
inputs=[self.pipeline_state, gr.State(position)],
|
157 |
-
outputs=[self.pipeline_state, self.
|
158 |
)
|
159 |
return add_step_btn
|
160 |
|
@@ -183,11 +186,9 @@ class PipelineInterface:
|
|
183 |
)
|
184 |
dropdowns[output_field] = dropdown
|
185 |
|
186 |
-
def update_choices(available_variables):
|
187 |
"""Update the choices for the dropdowns"""
|
188 |
-
return [
|
189 |
-
gr.update(choices=available_variables, value=None, selected=None) for dropdown in dropdowns.values()
|
190 |
-
]
|
191 |
|
192 |
self.variables_state.change(
|
193 |
update_choices,
|
@@ -196,16 +197,15 @@ class PipelineInterface:
|
|
196 |
)
|
197 |
return dropdowns
|
198 |
|
199 |
-
def validate_workflow(self,
|
200 |
"""Validate the workflow."""
|
201 |
try:
|
202 |
-
|
203 |
-
|
204 |
-
else:
|
205 |
-
workflow = validate_complex_workflow(state.workflow, self.required_output_variables)
|
206 |
-
state.workflow = workflow
|
207 |
-
return state
|
208 |
except ValueError as e:
|
|
|
|
|
|
|
209 |
raise gr.Error(e)
|
210 |
|
211 |
def _render_pipeline_header(self):
|
@@ -220,35 +220,35 @@ class PipelineInterface:
|
|
220 |
gr.Markdown(f"* Input Variables: {input_variables_str}")
|
221 |
gr.Markdown(f"* Output Variables: {output_variables_str}")
|
222 |
|
223 |
-
# if not self.simple:
|
224 |
-
# self._render_add_step_button(0)
|
225 |
-
|
226 |
def render(self):
|
227 |
"""Render the pipeline UI."""
|
228 |
# Create a placeholder for all the step components
|
229 |
self.all_components = []
|
230 |
|
231 |
-
# self.pipeline_state.change(
|
232 |
-
# lambda x, y: print(f"Pipeline state changed! UI:\n{x}\n\n Data:\n{y}"),
|
233 |
-
# inputs=[self.ui_state, self.pipeline_state],
|
234 |
-
# outputs=[],
|
235 |
-
# )
|
236 |
-
|
237 |
self._render_pipeline_header()
|
238 |
|
239 |
# Function to render all steps
|
240 |
-
@gr.render(
|
241 |
-
|
|
|
|
|
|
|
|
|
|
|
242 |
"""Render all steps in the pipeline"""
|
243 |
-
logger.info(
|
244 |
-
|
|
|
|
|
|
|
|
|
245 |
components = []
|
246 |
|
247 |
step_objects = [] # Reset step objects list
|
248 |
for i, step_id in enumerate(ui_state.step_ids):
|
249 |
step_data = workflow.steps[step_id]
|
250 |
step_ui_state = ui_state.steps[step_id]
|
251 |
-
available_variables =
|
252 |
sub_components = self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)
|
253 |
step_objects.append(sub_components)
|
254 |
|
@@ -258,11 +258,13 @@ class PipelineInterface:
|
|
258 |
self._render_add_step_button(-1)
|
259 |
|
260 |
@gr.render(
|
|
|
261 |
inputs=[self.variables_state, self.pipeline_state],
|
262 |
concurrency_limit=1,
|
263 |
concurrency_id="render_output_fields",
|
264 |
)
|
265 |
-
def render_output_fields(available_variables,
|
|
|
266 |
logger.info(f"Rerendering output panel: {available_variables} {pipeline_state.workflow}")
|
267 |
self._render_output_panel(available_variables, pipeline_state)
|
268 |
|
@@ -285,14 +287,14 @@ class PipelineInterface:
|
|
285 |
# components.append(config_accordion)
|
286 |
|
287 |
self.config_output.blur(
|
288 |
-
fn=update_workflow_from_code,
|
289 |
-
inputs=[self.config_output, self.
|
290 |
-
outputs=[self.pipeline_state],
|
291 |
)
|
292 |
|
293 |
# Connect the export button to show the workflow JSON
|
294 |
self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state)
|
295 |
-
export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[
|
296 |
fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
|
297 |
)
|
298 |
|
@@ -301,7 +303,7 @@ class PipelineInterface:
|
|
301 |
triggers,
|
302 |
self.validate_workflow,
|
303 |
inputs=[input_pipeline_state],
|
304 |
-
outputs=[
|
305 |
).success(
|
306 |
fn=self.sm.get_formatted_config,
|
307 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
|
|
4 |
|
5 |
from app_configs import UNSELECTED_VAR_NAME
|
6 |
from components import commons
|
7 |
+
from components import typed_dicts as td
|
8 |
from components.model_pipeline.state_manager import (
|
9 |
ModelStepUIState,
|
10 |
PipelineState,
|
11 |
PipelineStateManager,
|
12 |
PipelineUIState,
|
13 |
+
TossupPipelineState,
|
14 |
+
TossupPipelineStateManager,
|
15 |
)
|
16 |
from components.model_step.model_step import ModelStepComponent
|
17 |
from components.utils import make_state
|
18 |
+
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
19 |
from workflows.validators import WorkflowValidator
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
class PipelineInterface:
|
23 |
"""UI for the pipeline."""
|
24 |
|
25 |
def __init__(
|
26 |
self,
|
27 |
+
app: gr.Blocks,
|
28 |
workflow: Workflow,
|
29 |
ui_state: PipelineUIState | None = None,
|
30 |
model_options: list[str] = None,
|
31 |
simple: bool = False,
|
|
|
32 |
):
|
33 |
+
self.app = app
|
34 |
self.model_options = model_options
|
35 |
self.simple = simple
|
36 |
+
ui_state = ui_state or PipelineUIState.from_workflow(workflow)
|
37 |
+
|
38 |
+
# Gradio States
|
39 |
+
self.workflow_state = make_state(workflow.model_dump())
|
|
|
40 |
self.variables_state = make_state(workflow.get_available_variables())
|
41 |
+
self.model_selection_state = make_state(workflow.get_model_selections())
|
42 |
+
self.pipeline_change = gr.State(False)
|
43 |
+
|
44 |
+
if isinstance(workflow, TossupWorkflow):
|
45 |
+
pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
|
46 |
+
self.sm = TossupPipelineStateManager()
|
47 |
+
else:
|
48 |
+
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
49 |
+
self.sm = PipelineStateManager()
|
50 |
+
self.pipeline_state = make_state(pipeline_state.model_dump())
|
51 |
+
|
52 |
+
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
53 |
+
"""Get the auxiliary states for the pipeline."""
|
54 |
+
pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
|
55 |
+
return (
|
56 |
+
pipeline_state.workflow.model_dump(),
|
57 |
+
pipeline_state.workflow.get_available_variables(),
|
58 |
+
pipeline_state.workflow.get_model_selections(),
|
59 |
+
)
|
60 |
|
61 |
+
# Triggers for pipeline state changes
|
62 |
+
self.pipeline_state.change(
|
63 |
+
get_aux_states,
|
64 |
+
inputs=[self.pipeline_state],
|
65 |
+
outputs=[self.workflow_state, self.variables_state, self.model_selection_state],
|
66 |
+
)
|
67 |
+
|
68 |
+
self.workflow_state.change(
|
69 |
+
lambda x: logger.debug(f"Workflow state changed: {x}"),
|
70 |
+
inputs=[self.workflow_state],
|
71 |
+
outputs=[],
|
72 |
+
)
|
73 |
+
|
74 |
+
# IO Variables
|
75 |
self.input_variables = workflow.inputs
|
76 |
self.required_output_variables = list(workflow.outputs.keys())
|
77 |
|
|
|
103 |
step_interface.on_model_step_change(
|
104 |
self.sm.update_model_step_state,
|
105 |
inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
|
106 |
+
outputs=[self.pipeline_state],
|
107 |
)
|
108 |
|
109 |
step_interface.on_ui_change(
|
110 |
self.sm.update_model_step_ui,
|
111 |
inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
|
112 |
+
outputs=[self.pipeline_state],
|
113 |
)
|
114 |
|
115 |
if self.simple:
|
|
|
132 |
def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
|
133 |
up_button, down_button, remove_button = buttons
|
134 |
position = gr.State(position)
|
135 |
+
up_button.click(
|
136 |
+
self.sm.move_up,
|
137 |
+
inputs=[self.pipeline_state, self.pipeline_change, position],
|
138 |
+
outputs=[self.pipeline_state, self.pipeline_change],
|
139 |
+
)
|
140 |
+
down_button.click(
|
141 |
+
self.sm.move_down,
|
142 |
+
inputs=[self.pipeline_state, self.pipeline_change, position],
|
143 |
+
outputs=[self.pipeline_state, self.pipeline_change],
|
144 |
+
)
|
145 |
remove_button.click(
|
146 |
self.sm.remove_step,
|
147 |
+
inputs=[self.pipeline_state, self.pipeline_change, position],
|
148 |
+
outputs=[self.pipeline_state, self.pipeline_change],
|
149 |
)
|
150 |
|
151 |
def _render_add_step_button(self, position: int):
|
|
|
156 |
add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
|
157 |
add_step_btn.click(
|
158 |
self.sm.add_step,
|
159 |
+
inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
|
160 |
+
outputs=[self.pipeline_state, self.pipeline_change],
|
161 |
)
|
162 |
return add_step_btn
|
163 |
|
|
|
186 |
)
|
187 |
dropdowns[output_field] = dropdown
|
188 |
|
189 |
+
def update_choices(available_variables: list[str]):
|
190 |
"""Update the choices for the dropdowns"""
|
191 |
+
return [gr.update(choices=available_variables, value=None, selected=None) for _ in dropdowns.values()]
|
|
|
|
|
192 |
|
193 |
self.variables_state.change(
|
194 |
update_choices,
|
|
|
197 |
)
|
198 |
return dropdowns
|
199 |
|
200 |
+
def validate_workflow(self, state_dict: td.PipelineStateDict):
|
201 |
"""Validate the workflow."""
|
202 |
try:
|
203 |
+
state = self.sm.make_pipeline_state(state_dict)
|
204 |
+
WorkflowValidator().validate(state.workflow)
|
|
|
|
|
|
|
|
|
205 |
except ValueError as e:
|
206 |
+
logger.exception(e)
|
207 |
+
state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
|
208 |
+
logger.error(f"Could not validate workflow: \n{state_dict_str}")
|
209 |
raise gr.Error(e)
|
210 |
|
211 |
def _render_pipeline_header(self):
|
|
|
220 |
gr.Markdown(f"* Input Variables: {input_variables_str}")
|
221 |
gr.Markdown(f"* Output Variables: {output_variables_str}")
|
222 |
|
|
|
|
|
|
|
223 |
def render(self):
|
224 |
"""Render the pipeline UI."""
|
225 |
# Create a placeholder for all the step components
|
226 |
self.all_components = []
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
self._render_pipeline_header()
|
229 |
|
230 |
# Function to render all steps
|
231 |
+
@gr.render(
|
232 |
+
triggers=[self.app.load, self.pipeline_change.change],
|
233 |
+
inputs=[self.pipeline_state],
|
234 |
+
concurrency_limit=1,
|
235 |
+
concurrency_id="render_steps",
|
236 |
+
)
|
237 |
+
def render_steps(pipeline_state: td.PipelineStateDict, evt: gr.EventData):
|
238 |
"""Render all steps in the pipeline"""
|
239 |
+
logger.info(
|
240 |
+
f"Rerender triggered! \nInput Pipeline's UI State:{pipeline_state.get('ui_state')}\n Event: {evt.target} {evt._data}"
|
241 |
+
)
|
242 |
+
pipeline_state = self.sm.make_pipeline_state(pipeline_state)
|
243 |
+
ui_state = pipeline_state.ui_state
|
244 |
+
workflow = pipeline_state.workflow
|
245 |
components = []
|
246 |
|
247 |
step_objects = [] # Reset step objects list
|
248 |
for i, step_id in enumerate(ui_state.step_ids):
|
249 |
step_data = workflow.steps[step_id]
|
250 |
step_ui_state = ui_state.steps[step_id]
|
251 |
+
available_variables = pipeline_state.get_available_variables(step_id)
|
252 |
sub_components = self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)
|
253 |
step_objects.append(sub_components)
|
254 |
|
|
|
258 |
self._render_add_step_button(-1)
|
259 |
|
260 |
@gr.render(
|
261 |
+
triggers=[self.variables_state.change, self.app.load],
|
262 |
inputs=[self.variables_state, self.pipeline_state],
|
263 |
concurrency_limit=1,
|
264 |
concurrency_id="render_output_fields",
|
265 |
)
|
266 |
+
def render_output_fields(available_variables: list[str], pipeline_state_dict: td.PipelineStateDict):
|
267 |
+
pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
|
268 |
logger.info(f"Rerendering output panel: {available_variables} {pipeline_state.workflow}")
|
269 |
self._render_output_panel(available_variables, pipeline_state)
|
270 |
|
|
|
287 |
# components.append(config_accordion)
|
288 |
|
289 |
self.config_output.blur(
|
290 |
+
fn=self.sm.update_workflow_from_code,
|
291 |
+
inputs=[self.config_output, self.pipeline_change],
|
292 |
+
outputs=[self.pipeline_state, self.pipeline_change],
|
293 |
)
|
294 |
|
295 |
# Connect the export button to show the workflow JSON
|
296 |
self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state)
|
297 |
+
export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[]).success(
|
298 |
fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
|
299 |
)
|
300 |
|
|
|
303 |
triggers,
|
304 |
self.validate_workflow,
|
305 |
inputs=[input_pipeline_state],
|
306 |
+
outputs=[],
|
307 |
).success(
|
308 |
fn=self.sm.get_formatted_config,
|
309 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
src/components/model_pipeline/state_manager.py
CHANGED
@@ -1,137 +1,26 @@
|
|
1 |
import json
|
2 |
-
from typing import
|
3 |
|
4 |
-
import gradio as gr
|
5 |
import yaml
|
6 |
-
from loguru import logger
|
7 |
-
from pydantic import BaseModel, Field
|
8 |
|
|
|
|
|
9 |
from components import utils
|
|
|
10 |
from workflows.factory import create_new_llm_step
|
11 |
-
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
12 |
-
|
13 |
-
|
14 |
-
def make_step_id(step_number: int):
|
15 |
-
"""Make a step id from a step name."""
|
16 |
-
if step_number < 26:
|
17 |
-
return chr(ord("A") + step_number)
|
18 |
-
else:
|
19 |
-
# For more than 26 steps, use AA, AB, AC, etc.
|
20 |
-
first_char = chr(ord("A") + (step_number // 26) - 1)
|
21 |
-
second_char = chr(ord("A") + (step_number % 26))
|
22 |
-
return f"{first_char}{second_char}"
|
23 |
-
|
24 |
-
|
25 |
-
def make_step_number(step_id: str):
|
26 |
-
"""Make a step number from a step id."""
|
27 |
-
if len(step_id) == 1:
|
28 |
-
return ord(step_id) - ord("A")
|
29 |
-
else:
|
30 |
-
return (ord(step_id[0]) - ord("A")) * 26 + (ord(step_id[1]) - ord("A")) + 1
|
31 |
-
|
32 |
-
|
33 |
-
class ModelStepUIState(BaseModel):
|
34 |
-
"""Represents the UI state for a model step component."""
|
35 |
-
|
36 |
-
expanded: bool = True
|
37 |
-
active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"
|
38 |
-
|
39 |
-
def update(self, key: str, value: Any) -> "ModelStepUIState":
|
40 |
-
"""Update the UI state."""
|
41 |
-
new_state = self.model_copy(update={key: value})
|
42 |
-
return new_state
|
43 |
-
|
44 |
-
|
45 |
-
class PipelineUIState(BaseModel):
|
46 |
-
"""Represents the UI state for a pipeline component."""
|
47 |
-
|
48 |
-
step_ids: list[str] = Field(default_factory=list)
|
49 |
-
steps: dict[str, ModelStepUIState] = Field(default_factory=dict)
|
50 |
-
|
51 |
-
def model_post_init(self, __context: utils.Any) -> None:
|
52 |
-
if not self.steps and self.step_ids:
|
53 |
-
self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
|
54 |
-
return super().model_post_init(__context)
|
55 |
-
|
56 |
-
def get_step_position(self, step_id: str):
|
57 |
-
"""Get the position of a step in the pipeline."""
|
58 |
-
return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)
|
59 |
-
|
60 |
-
@property
|
61 |
-
def n_steps(self) -> int:
|
62 |
-
"""Get the number of steps in the pipeline."""
|
63 |
-
return len(self.step_ids)
|
64 |
-
|
65 |
-
@classmethod
|
66 |
-
def from_workflow(cls, workflow: Workflow):
|
67 |
-
"""Create a pipeline UI state from a workflow."""
|
68 |
-
return PipelineUIState(
|
69 |
-
step_ids=list(workflow.steps.keys()),
|
70 |
-
steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
|
71 |
-
)
|
72 |
-
|
73 |
-
|
74 |
-
class PipelineState(BaseModel):
|
75 |
-
"""Represents the state for a pipeline component."""
|
76 |
-
|
77 |
-
workflow: Workflow
|
78 |
-
ui_state: PipelineUIState
|
79 |
-
|
80 |
-
def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
|
81 |
-
if step.id in self.workflow.steps:
|
82 |
-
raise ValueError(f"Step {step.id} already exists in pipeline")
|
83 |
-
|
84 |
-
# Validate position
|
85 |
-
if position != -1 and (position < 0 or position > self.n_steps):
|
86 |
-
raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
|
87 |
-
|
88 |
-
self.workflow.steps[step.id] = step
|
89 |
-
|
90 |
-
self.ui_state = self.ui_state.model_copy()
|
91 |
-
self.ui_state.steps[step.id] = ModelStepUIState()
|
92 |
-
if position == -1:
|
93 |
-
self.ui_state.step_ids.append(step.id)
|
94 |
-
else:
|
95 |
-
self.ui_state.step_ids.insert(position, step.id)
|
96 |
-
return self
|
97 |
-
|
98 |
-
def remove_step(self, position: int) -> "PipelineState":
|
99 |
-
step_id = self.ui_state.step_ids.pop(position)
|
100 |
-
self.workflow.steps.pop(step_id)
|
101 |
-
self.ui_state = self.ui_state.model_copy()
|
102 |
-
self.ui_state.steps.pop(step_id)
|
103 |
-
self.update_output_variables_mapping()
|
104 |
-
return self
|
105 |
-
|
106 |
-
def update_output_variables_mapping(self) -> "PipelineState":
|
107 |
-
available_variables = set(self.available_variables)
|
108 |
-
for output_field in self.workflow.outputs:
|
109 |
-
if self.workflow.outputs[output_field] not in available_variables:
|
110 |
-
self.workflow.outputs[output_field] = None
|
111 |
-
return self
|
112 |
-
|
113 |
-
@property
|
114 |
-
def available_variables(self) -> list[str]:
|
115 |
-
return self.workflow.get_available_variables()
|
116 |
-
|
117 |
-
@property
|
118 |
-
def n_steps(self) -> int:
|
119 |
-
return len(self.workflow.steps)
|
120 |
-
|
121 |
-
def get_new_step_id(self) -> str:
|
122 |
-
"""Get a step ID for a new step."""
|
123 |
-
if not self.workflow.steps:
|
124 |
-
return "A"
|
125 |
-
else:
|
126 |
-
last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
|
127 |
-
return make_step_id(last_step_number + 1)
|
128 |
|
129 |
|
130 |
class PipelineStateManager:
|
131 |
"""Manages a pipeline of multiple steps."""
|
132 |
|
133 |
-
def
|
|
|
|
|
|
|
|
|
134 |
"""Get the full pipeline configuration."""
|
|
|
135 |
config = state.workflow.model_dump(exclude_defaults=True)
|
136 |
if isinstance(state.workflow, TossupWorkflow):
|
137 |
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
@@ -141,65 +30,121 @@ class PipelineStateManager:
|
|
141 |
else:
|
142 |
return json.dumps(config, indent=4, sort_keys=False)
|
143 |
|
144 |
-
def
|
145 |
-
|
146 |
-
|
147 |
-
def add_step(self, state: PipelineState, position: int = -1, name=""):
|
148 |
"""Create a new step and return its state."""
|
|
|
149 |
step_id = state.get_new_step_id()
|
150 |
step_name = name or f"Step {state.n_steps + 1}"
|
151 |
new_step = create_new_llm_step(step_id=step_id, name=step_name)
|
152 |
state = state.insert_step(position, new_step)
|
153 |
-
return state
|
154 |
|
155 |
-
def remove_step(
|
|
|
|
|
156 |
"""Remove a step from the pipeline."""
|
|
|
157 |
if 0 <= position < state.n_steps:
|
158 |
state = state.remove_step(position)
|
159 |
else:
|
160 |
raise ValueError(f"Invalid step position: {position}")
|
161 |
-
return state
|
162 |
|
163 |
-
def
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
167 |
|
168 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
"""Move a step down in the pipeline."""
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
184 |
produced_variable = None
|
185 |
"""Update the output variables for a step."""
|
186 |
-
state.
|
187 |
-
|
|
|
188 |
|
189 |
-
def update_model_step_ui(
|
|
|
|
|
190 |
"""Update a step in the pipeline."""
|
|
|
191 |
state.ui_state.steps[step_id] = step_ui.model_copy()
|
192 |
-
return state
|
193 |
|
194 |
-
def get_all_variables(self,
|
195 |
"""Get all variables from all steps."""
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
def
|
204 |
-
"""
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
from typing import Literal
|
3 |
|
|
|
4 |
import yaml
|
|
|
|
|
5 |
|
6 |
+
from app_configs import UNSELECTED_VAR_NAME
|
7 |
+
from components import typed_dicts as td
|
8 |
from components import utils
|
9 |
+
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
|
10 |
from workflows.factory import create_new_llm_step
|
11 |
+
from workflows.structs import Buzzer, ModelStep, TossupWorkflow, Workflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class PipelineStateManager:
|
15 |
"""Manages a pipeline of multiple steps."""
|
16 |
|
17 |
+
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
|
18 |
+
"""Make a state from a state dictionary."""
|
19 |
+
return PipelineState(**state_dict)
|
20 |
+
|
21 |
+
def get_formatted_config(self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml") -> str:
|
22 |
"""Get the full pipeline configuration."""
|
23 |
+
state = self.make_pipeline_state(state_dict)
|
24 |
config = state.workflow.model_dump(exclude_defaults=True)
|
25 |
if isinstance(state.workflow, TossupWorkflow):
|
26 |
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
|
|
30 |
else:
|
31 |
return json.dumps(config, indent=4, sort_keys=False)
|
32 |
|
33 |
+
def add_step(
|
34 |
+
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
35 |
+
) -> td.PipelineStateDict:
|
|
|
36 |
"""Create a new step and return its state."""
|
37 |
+
state = self.make_pipeline_state(state_dict)
|
38 |
step_id = state.get_new_step_id()
|
39 |
step_name = name or f"Step {state.n_steps + 1}"
|
40 |
new_step = create_new_llm_step(step_id=step_id, name=step_name)
|
41 |
state = state.insert_step(position, new_step)
|
42 |
+
return state.model_dump(), not pipeline_change
|
43 |
|
44 |
+
def remove_step(
|
45 |
+
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
|
46 |
+
) -> td.PipelineStateDict:
|
47 |
"""Remove a step from the pipeline."""
|
48 |
+
state = self.make_pipeline_state(state_dict)
|
49 |
if 0 <= position < state.n_steps:
|
50 |
state = state.remove_step(position)
|
51 |
else:
|
52 |
raise ValueError(f"Invalid step position: {position}")
|
53 |
+
return state.model_dump(), not pipeline_change
|
54 |
|
55 |
+
def _move_step(
|
56 |
+
self, state_dict: td.PipelineStateDict, position: int, direction: Literal["up", "down"]
|
57 |
+
) -> tuple[td.PipelineStateDict, bool]:
|
58 |
+
state = self.make_pipeline_state(state_dict)
|
59 |
+
old_order = list(state.ui_state.step_ids)
|
60 |
+
utils.move_item(state.ui_state.step_ids, position, direction)
|
61 |
+
return state.model_dump(), old_order != list(state.ui_state.step_ids)
|
62 |
|
63 |
+
def move_up(self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int) -> td.PipelineStateDict:
|
64 |
+
"""Move a step up in the pipeline."""
|
65 |
+
new_state_dict, change = self._move_step(state_dict, position, "up")
|
66 |
+
if change:
|
67 |
+
pipeline_change = not pipeline_change
|
68 |
+
return new_state_dict, pipeline_change
|
69 |
+
|
70 |
+
def move_down(
|
71 |
+
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
|
72 |
+
) -> td.PipelineStateDict:
|
73 |
"""Move a step down in the pipeline."""
|
74 |
+
new_state_dict, change = self._move_step(state_dict, position, "down")
|
75 |
+
if change:
|
76 |
+
pipeline_change = not pipeline_change
|
77 |
+
return new_state_dict, pipeline_change
|
78 |
+
|
79 |
+
def update_model_step_state(
|
80 |
+
self, state_dict: td.PipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
|
81 |
+
) -> td.PipelineStateDict:
|
82 |
+
"""Update a particular model step in the pipeline."""
|
83 |
+
state = self.make_pipeline_state(state_dict)
|
84 |
+
state = state.update_step(model_step, ui_state)
|
85 |
+
return state.model_dump()
|
86 |
+
|
87 |
+
def update_output_variables(
|
88 |
+
self, state_dict: td.PipelineStateDict, target: str, produced_variable: str
|
89 |
+
) -> td.PipelineStateDict:
|
90 |
+
if produced_variable == UNSELECTED_VAR_NAME:
|
91 |
produced_variable = None
|
92 |
"""Update the output variables for a step."""
|
93 |
+
state = self.make_pipeline_state(state_dict)
|
94 |
+
state.workflow.outputs[target] = produced_variable
|
95 |
+
return state.model_dump()
|
96 |
|
97 |
+
def update_model_step_ui(
|
98 |
+
self, state_dict: td.PipelineStateDict, step_ui: ModelStepUIState, step_id: str
|
99 |
+
) -> td.PipelineStateDict:
|
100 |
"""Update a step in the pipeline."""
|
101 |
+
state = self.make_pipeline_state(state_dict)
|
102 |
state.ui_state.steps[step_id] = step_ui.model_copy()
|
103 |
+
return state.model_dump()
|
104 |
|
105 |
+
def get_all_variables(self, state_dict: td.PipelineStateDict, model_step_id: str | None = None) -> list[str]:
|
106 |
"""Get all variables from all steps."""
|
107 |
+
return self.make_pipeline_state(state_dict)
|
108 |
+
|
109 |
+
def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
|
110 |
+
"""Parse a YAML workflow."""
|
111 |
+
workflow = yaml.safe_load(yaml_str)
|
112 |
+
return Workflow(**workflow)
|
113 |
+
|
114 |
+
def update_workflow_from_code(self, yaml_str: str) -> td.PipelineStateDict:
|
115 |
+
"""Update a workflow from a YAML string."""
|
116 |
+
workflow = self.parse_yaml_workflow(yaml_str)
|
117 |
+
return PipelineState.from_workflow(workflow).model_dump()
|
118 |
+
|
119 |
+
|
120 |
+
class TossupPipelineStateManager(PipelineStateManager):
|
121 |
+
"""Manages a tossup pipeline state."""
|
122 |
+
|
123 |
+
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
|
124 |
+
"""Make a state from a state dictionary."""
|
125 |
+
return TossupPipelineState(**state_dict)
|
126 |
+
|
127 |
+
def parse_yaml_workflow(self, yaml_str: str) -> TossupWorkflow:
|
128 |
+
"""Parse a YAML workflow."""
|
129 |
+
workflow = yaml.safe_load(yaml_str)
|
130 |
+
return TossupWorkflow(**workflow)
|
131 |
+
|
132 |
+
def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool]:
|
133 |
+
"""Update a workflow from a YAML string."""
|
134 |
+
workflow = self.parse_yaml_workflow(yaml_str)
|
135 |
+
return TossupPipelineState.from_workflow(workflow).model_dump(), not change_state
|
136 |
+
|
137 |
+
def update_buzzer(
|
138 |
+
self,
|
139 |
+
state_dict: td.TossupPipelineStateDict,
|
140 |
+
confidence_threshold: float,
|
141 |
+
method: str,
|
142 |
+
tokens_prob: float | None,
|
143 |
+
) -> td.TossupPipelineStateDict:
|
144 |
+
"""Update the buzzer."""
|
145 |
+
state = self.make_pipeline_state(state_dict)
|
146 |
+
prob_threshold = float(tokens_prob) if tokens_prob and tokens_prob > 0 else None
|
147 |
+
state.workflow.buzzer = Buzzer(
|
148 |
+
method=method, confidence_threshold=confidence_threshold, prob_threshold=prob_threshold
|
149 |
+
)
|
150 |
+
return state.model_dump()
|
src/components/model_pipeline/tossup_pipeline.py
CHANGED
@@ -4,10 +4,13 @@ from loguru import logger
|
|
4 |
|
5 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
6 |
from components import commons
|
|
|
|
|
7 |
from display.formatting import tiny_styled_warning
|
8 |
from workflows.structs import Buzzer, TossupWorkflow
|
9 |
|
10 |
from .model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
|
|
11 |
|
12 |
|
13 |
def toggleable_slider(
|
@@ -30,48 +33,34 @@ def toggleable_slider(
|
|
30 |
return checkbox, slider
|
31 |
|
32 |
|
33 |
-
class TossupPipelineState(PipelineState):
|
34 |
-
workflow: TossupWorkflow
|
35 |
-
|
36 |
-
|
37 |
class TossupPipelineInterface(PipelineInterface):
|
38 |
def __init__(
|
39 |
self,
|
|
|
40 |
workflow: TossupWorkflow,
|
41 |
ui_state: PipelineUIState | None = None,
|
42 |
model_options: list[str] = None,
|
43 |
simple: bool = False,
|
44 |
-
show_pipeline_selector: bool = False,
|
45 |
defaults: dict = {},
|
46 |
):
|
47 |
-
super().__init__(workflow, ui_state, model_options, simple
|
48 |
self.defaults = defaults
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
tokens_prob: float | None,
|
56 |
-
):
|
57 |
-
"""Update the buzzer."""
|
58 |
-
|
59 |
-
prob_threshold = float(tokens_prob) if tokens_prob and tokens_prob > 0 else None
|
60 |
-
state.workflow.buzzer = state.workflow.buzzer.model_copy(
|
61 |
-
update={
|
62 |
-
"method": method,
|
63 |
-
"confidence_threshold": confidence_threshold,
|
64 |
-
"prob_threshold": prob_threshold,
|
65 |
-
}
|
66 |
)
|
67 |
-
Buzzer.model_validate(state.workflow.buzzer)
|
68 |
-
return state
|
69 |
|
70 |
-
def update_prob_slider(
|
|
|
|
|
71 |
"""Update the probability slider based on the answer variable."""
|
|
|
72 |
if answer_var == UNSELECTED_VAR_NAME:
|
73 |
return (
|
74 |
-
state,
|
75 |
gr.update(interactive=True),
|
76 |
gr.update(value="AND", interactive=True),
|
77 |
gr.update(visible=False),
|
@@ -83,14 +72,13 @@ class TossupPipelineInterface(PipelineInterface):
|
|
83 |
buzzer = state.workflow.buzzer
|
84 |
tokens_prob_threshold = tokens_prob if is_model_with_logprobs else None
|
85 |
method = buzzer.method if is_model_with_logprobs else "AND"
|
86 |
-
state =
|
87 |
-
state,
|
88 |
-
confidence_threshold=buzzer.confidence_threshold,
|
89 |
method=method,
|
90 |
-
|
|
|
91 |
)
|
92 |
return (
|
93 |
-
state,
|
94 |
gr.update(interactive=is_model_with_logprobs),
|
95 |
gr.update(value=method, interactive=is_model_with_logprobs),
|
96 |
gr.update(
|
@@ -156,11 +144,9 @@ class TossupPipelineInterface(PipelineInterface):
|
|
156 |
)
|
157 |
self.buzzer_warning_display = gr.HTML(visible=False)
|
158 |
|
159 |
-
def update_choices(available_variables):
|
160 |
"""Update the choices for the dropdowns"""
|
161 |
-
return [
|
162 |
-
gr.update(choices=available_variables, value=None, selected=None) for dropdown in dropdowns.values()
|
163 |
-
]
|
164 |
|
165 |
self.variables_state.change(
|
166 |
update_choices,
|
@@ -170,17 +156,12 @@ class TossupPipelineInterface(PipelineInterface):
|
|
170 |
|
171 |
gr.on(
|
172 |
triggers=[
|
173 |
-
self.confidence_slider.
|
174 |
self.buzzer_method_dropdown.input,
|
175 |
-
self.prob_slider.
|
176 |
-
],
|
177 |
-
fn=self.update_buzzer,
|
178 |
-
inputs=[
|
179 |
-
self.pipeline_state,
|
180 |
-
self.confidence_slider,
|
181 |
-
self.buzzer_method_dropdown,
|
182 |
-
self.prob_slider,
|
183 |
],
|
|
|
|
|
184 |
outputs=[self.pipeline_state],
|
185 |
)
|
186 |
|
|
|
4 |
|
5 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
6 |
from components import commons
|
7 |
+
from components.structs import TossupPipelineState
|
8 |
+
from components.typed_dicts import TossupPipelineStateDict
|
9 |
from display.formatting import tiny_styled_warning
|
10 |
from workflows.structs import Buzzer, TossupWorkflow
|
11 |
|
12 |
from .model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
13 |
+
from .state_manager import PipelineStateManager, TossupPipelineStateManager
|
14 |
|
15 |
|
16 |
def toggleable_slider(
|
|
|
33 |
return checkbox, slider
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
36 |
class TossupPipelineInterface(PipelineInterface):
|
37 |
def __init__(
|
38 |
self,
|
39 |
+
app: gr.Blocks,
|
40 |
workflow: TossupWorkflow,
|
41 |
ui_state: PipelineUIState | None = None,
|
42 |
model_options: list[str] = None,
|
43 |
simple: bool = False,
|
|
|
44 |
defaults: dict = {},
|
45 |
):
|
46 |
+
super().__init__(app, workflow, ui_state, model_options, simple)
|
47 |
self.defaults = defaults
|
48 |
|
49 |
+
self.pipeline_state.change(
|
50 |
+
lambda x: logger.debug(
|
51 |
+
f"Pipeline state changed. Type: {type(x)}. Has buzzer info: {x['workflow']['buzzer'] if isinstance(x, dict) else 'N/A'}"
|
52 |
+
),
|
53 |
+
inputs=[self.pipeline_state],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
)
|
|
|
|
|
55 |
|
56 |
+
def update_prob_slider(
|
57 |
+
self, state_dict: TossupPipelineStateDict, answer_var: str, tokens_prob: float | None
|
58 |
+
) -> tuple[TossupPipelineStateDict, dict, dict, dict]:
|
59 |
"""Update the probability slider based on the answer variable."""
|
60 |
+
state = TossupPipelineState(**state_dict)
|
61 |
if answer_var == UNSELECTED_VAR_NAME:
|
62 |
return (
|
63 |
+
state.model_dump(),
|
64 |
gr.update(interactive=True),
|
65 |
gr.update(value="AND", interactive=True),
|
66 |
gr.update(visible=False),
|
|
|
72 |
buzzer = state.workflow.buzzer
|
73 |
tokens_prob_threshold = tokens_prob if is_model_with_logprobs else None
|
74 |
method = buzzer.method if is_model_with_logprobs else "AND"
|
75 |
+
state.workflow.buzzer = Buzzer(
|
|
|
|
|
76 |
method=method,
|
77 |
+
confidence_threshold=buzzer.confidence_threshold,
|
78 |
+
prob_threshold=tokens_prob_threshold,
|
79 |
)
|
80 |
return (
|
81 |
+
state.model_dump(),
|
82 |
gr.update(interactive=is_model_with_logprobs),
|
83 |
gr.update(value=method, interactive=is_model_with_logprobs),
|
84 |
gr.update(
|
|
|
144 |
)
|
145 |
self.buzzer_warning_display = gr.HTML(visible=False)
|
146 |
|
147 |
+
def update_choices(available_variables: list[str]):
|
148 |
"""Update the choices for the dropdowns"""
|
149 |
+
return [gr.update(choices=available_variables, value=None, selected=None) for _ in dropdowns.values()]
|
|
|
|
|
150 |
|
151 |
self.variables_state.change(
|
152 |
update_choices,
|
|
|
156 |
|
157 |
gr.on(
|
158 |
triggers=[
|
159 |
+
self.confidence_slider.release,
|
160 |
self.buzzer_method_dropdown.input,
|
161 |
+
self.prob_slider.release,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
],
|
163 |
+
fn=self.sm.update_buzzer,
|
164 |
+
inputs=[self.pipeline_state, self.confidence_slider, self.buzzer_method_dropdown, self.prob_slider],
|
165 |
outputs=[self.pipeline_state],
|
166 |
)
|
167 |
|
src/components/model_step/model_step.py
CHANGED
@@ -5,7 +5,8 @@ import gradio as gr
|
|
5 |
from gradio.components import FormComponent
|
6 |
|
7 |
from app_configs import UNSELECTED_VAR_NAME
|
8 |
-
from components.model_pipeline.state_manager import ModelStepUIState,
|
|
|
9 |
from utils import get_full_model_name
|
10 |
from workflows.structs import ModelStep
|
11 |
|
@@ -454,12 +455,12 @@ class ModelStepComponent(FormComponent):
|
|
454 |
new_label = _make_accordion_label(new_model_step)
|
455 |
return new_model_step, gr.update(label=new_label)
|
456 |
|
457 |
-
def refresh_variable_dropdowns(self,
|
458 |
# TODO: Fix this. Not sure why this is needed.
|
459 |
"""Refresh the variable dropdown options in all input rows."""
|
460 |
variable_choices = []
|
461 |
if self.pipeline_sm is not None:
|
462 |
-
variable_choices = self.pipeline_sm.get_all_variables(
|
463 |
|
464 |
for _, fields, _ in self.input_rows:
|
465 |
_, inp_var, _ = fields
|
|
|
5 |
from gradio.components import FormComponent
|
6 |
|
7 |
from app_configs import UNSELECTED_VAR_NAME
|
8 |
+
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
|
9 |
+
from components.typed_dicts import PipelineStateDict
|
10 |
from utils import get_full_model_name
|
11 |
from workflows.structs import ModelStep
|
12 |
|
|
|
455 |
new_label = _make_accordion_label(new_model_step)
|
456 |
return new_model_step, gr.update(label=new_label)
|
457 |
|
458 |
+
def refresh_variable_dropdowns(self, pipeline_state_dict: PipelineStateDict):
|
459 |
# TODO: Fix this. Not sure why this is needed.
|
460 |
"""Refresh the variable dropdown options in all input rows."""
|
461 |
variable_choices = []
|
462 |
if self.pipeline_sm is not None:
|
463 |
+
variable_choices = self.pipeline_sm.get_all_variables(pipeline_state_dict)
|
464 |
|
465 |
for _, fields, _ in self.input_rows:
|
466 |
_, inp_var, _ = fields
|
src/components/quizbowl/bonus.py
CHANGED
@@ -9,6 +9,7 @@ from loguru import logger
|
|
9 |
from app_configs import UNSELECTED_PIPELINE_NAME
|
10 |
from components import commons
|
11 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
|
|
12 |
from display.formatting import styled_error
|
13 |
from submission import submit
|
14 |
from workflows.qb_agents import QuizBowlBonusAgent
|
@@ -128,6 +129,7 @@ class BonusInterface:
|
|
128 |
self.pipeline_selector = commons.get_pipeline_selector([])
|
129 |
self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
|
130 |
self.pipeline_interface = PipelineInterface(
|
|
|
131 |
workflow,
|
132 |
simple=simple,
|
133 |
model_options=list(self.model_options.keys()),
|
@@ -214,24 +216,28 @@ class BonusInterface:
|
|
214 |
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
|
215 |
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
|
216 |
|
217 |
-
def load_pipeline(
|
|
|
|
|
218 |
try:
|
219 |
-
|
220 |
-
if
|
221 |
-
|
222 |
-
|
|
|
|
|
223 |
except Exception as e:
|
224 |
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
|
225 |
-
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.update(visible=True, value=error_msg)
|
226 |
|
227 |
def single_run(
|
228 |
self,
|
229 |
question_id: int,
|
230 |
-
|
231 |
) -> tuple[str, Any, Any]:
|
232 |
"""Run the agent in bonus mode."""
|
233 |
try:
|
234 |
-
|
235 |
question_id = int(question_id - 1)
|
236 |
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
237 |
return "Invalid question ID or dataset not loaded", None, None
|
@@ -263,9 +269,10 @@ class BonusInterface:
|
|
263 |
gr.update(visible=True, value=error_msg),
|
264 |
)
|
265 |
|
266 |
-
def evaluate(self,
|
267 |
"""Evaluate the bonus questions."""
|
268 |
try:
|
|
|
269 |
# Validate inputs
|
270 |
if not self.ds or not self.ds.num_rows:
|
271 |
return "No dataset loaded", None, None
|
@@ -307,9 +314,14 @@ class BonusInterface:
|
|
307 |
return gr.skip(), gr.update(visible=True, value=error_msg)
|
308 |
|
309 |
def submit_model(
|
310 |
-
self,
|
|
|
|
|
|
|
|
|
311 |
):
|
312 |
"""Submit the model output."""
|
|
|
313 |
return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)
|
314 |
|
315 |
def _setup_event_listeners(self):
|
@@ -328,20 +340,19 @@ class BonusInterface:
|
|
328 |
outputs=[self.pipeline_selector],
|
329 |
)
|
330 |
|
331 |
-
|
|
|
332 |
self.load_btn.click(
|
333 |
fn=self.load_pipeline,
|
334 |
-
inputs=[self.pipeline_selector],
|
335 |
-
outputs=[self.pipeline_selector,
|
336 |
-
)
|
337 |
-
self.pipeline_interface.add_triggers_for_pipeline_export(
|
338 |
-
[self.new_loaded_pipeline_state.change], self.new_loaded_pipeline_state
|
339 |
)
|
|
|
340 |
|
341 |
self.run_btn.click(
|
342 |
self.pipeline_interface.validate_workflow,
|
343 |
inputs=[self.pipeline_interface.pipeline_state],
|
344 |
-
outputs=[
|
345 |
).success(
|
346 |
self.single_run,
|
347 |
inputs=[
|
|
|
9 |
from app_configs import UNSELECTED_PIPELINE_NAME
|
10 |
from components import commons
|
11 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
12 |
+
from components.typed_dicts import PipelineStateDict
|
13 |
from display.formatting import styled_error
|
14 |
from submission import submit
|
15 |
from workflows.qb_agents import QuizBowlBonusAgent
|
|
|
129 |
self.pipeline_selector = commons.get_pipeline_selector([])
|
130 |
self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
|
131 |
self.pipeline_interface = PipelineInterface(
|
132 |
+
self.app,
|
133 |
workflow,
|
134 |
simple=simple,
|
135 |
model_options=list(self.model_options.keys()),
|
|
|
216 |
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
|
217 |
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
|
218 |
|
219 |
+
def load_pipeline(
|
220 |
+
self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
|
221 |
+
) -> tuple[str, PipelineStateDict, bool, dict]:
|
222 |
try:
|
223 |
+
workflow = populate.load_workflow("bonus", model_name, profile)
|
224 |
+
if workflow is None:
|
225 |
+
logger.warning(f"Could not load workflow for {model_name}")
|
226 |
+
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
|
227 |
+
pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
|
228 |
+
return UNSELECTED_PIPELINE_NAME, pipeline_state_dict, not pipeline_change, gr.update(visible=True)
|
229 |
except Exception as e:
|
230 |
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
|
231 |
+
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
|
232 |
|
233 |
def single_run(
|
234 |
self,
|
235 |
question_id: int,
|
236 |
+
state_dict: PipelineStateDict,
|
237 |
) -> tuple[str, Any, Any]:
|
238 |
"""Run the agent in bonus mode."""
|
239 |
try:
|
240 |
+
pipeline_state = PipelineState(**state_dict)
|
241 |
question_id = int(question_id - 1)
|
242 |
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
243 |
return "Invalid question ID or dataset not loaded", None, None
|
|
|
269 |
gr.update(visible=True, value=error_msg),
|
270 |
)
|
271 |
|
272 |
+
def evaluate(self, state_dict: PipelineStateDict, progress: gr.Progress = gr.Progress()):
|
273 |
"""Evaluate the bonus questions."""
|
274 |
try:
|
275 |
+
pipeline_state = PipelineState(**state_dict)
|
276 |
# Validate inputs
|
277 |
if not self.ds or not self.ds.num_rows:
|
278 |
return "No dataset loaded", None, None
|
|
|
314 |
return gr.skip(), gr.update(visible=True, value=error_msg)
|
315 |
|
316 |
def submit_model(
|
317 |
+
self,
|
318 |
+
model_name: str,
|
319 |
+
description: str,
|
320 |
+
state_dict: PipelineStateDict,
|
321 |
+
profile: gr.OAuthProfile = None,
|
322 |
):
|
323 |
"""Submit the model output."""
|
324 |
+
pipeline_state = PipelineState(**state_dict)
|
325 |
return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)
|
326 |
|
327 |
def _setup_event_listeners(self):
|
|
|
340 |
outputs=[self.pipeline_selector],
|
341 |
)
|
342 |
|
343 |
+
pipeline_state = self.pipeline_interface.pipeline_state
|
344 |
+
pipeline_change = self.pipeline_interface.pipeline_change
|
345 |
self.load_btn.click(
|
346 |
fn=self.load_pipeline,
|
347 |
+
inputs=[self.pipeline_selector, pipeline_change],
|
348 |
+
outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.error_display],
|
|
|
|
|
|
|
349 |
)
|
350 |
+
self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
|
351 |
|
352 |
self.run_btn.click(
|
353 |
self.pipeline_interface.validate_workflow,
|
354 |
inputs=[self.pipeline_interface.pipeline_state],
|
355 |
+
outputs=[],
|
356 |
).success(
|
357 |
self.single_run,
|
358 |
inputs=[
|
src/components/quizbowl/populate.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
|
|
4 |
from loguru import logger
|
5 |
|
6 |
from app_configs import UNSELECTED_PIPELINE_NAME
|
7 |
-
from components.
|
8 |
from display.formatting import styled_error
|
9 |
from submission import submit
|
10 |
|
@@ -24,7 +24,9 @@ def get_pipeline_names(competition_type: str, profile: gr.OAuthProfile | None) -
|
|
24 |
return all_names
|
25 |
|
26 |
|
27 |
-
def
|
|
|
|
|
28 |
if not model_name or model_name == UNSELECTED_PIPELINE_NAME:
|
29 |
return None
|
30 |
username, model_name = model_name.split("/")
|
@@ -35,4 +37,5 @@ def load_pipeline(competition_type: str, model_name: str, profile: gr.OAuthProfi
|
|
35 |
workflow = submission.workflow
|
36 |
else:
|
37 |
raise gr.Error("Authentication required. Please log in to view your submissions.")
|
38 |
-
|
|
|
|
4 |
from loguru import logger
|
5 |
|
6 |
from app_configs import UNSELECTED_PIPELINE_NAME
|
7 |
+
from components.structs import TossupWorkflow, Workflow
|
8 |
from display.formatting import styled_error
|
9 |
from submission import submit
|
10 |
|
|
|
24 |
return all_names
|
25 |
|
26 |
|
27 |
+
def load_workflow(
|
28 |
+
competition_type: str, model_name: str, profile: gr.OAuthProfile | None
|
29 |
+
) -> Workflow | TossupWorkflow | None:
|
30 |
if not model_name or model_name == UNSELECTED_PIPELINE_NAME:
|
31 |
return None
|
32 |
username, model_name = model_name.split("/")
|
|
|
37 |
workflow = submission.workflow
|
38 |
else:
|
39 |
raise gr.Error("Authentication required. Please log in to view your submissions.")
|
40 |
+
|
41 |
+
return workflow
|
src/components/quizbowl/tossup.py
CHANGED
@@ -11,6 +11,7 @@ from app_configs import UNSELECTED_PIPELINE_NAME
|
|
11 |
from components import commons
|
12 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
13 |
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
|
|
|
14 |
from display.formatting import styled_error
|
15 |
from submission import submit
|
16 |
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
|
@@ -174,13 +175,7 @@ def validate_model_step(model_step: ModelStep):
|
|
174 |
class TossupInterface:
|
175 |
"""Gradio interface for the Tossup mode."""
|
176 |
|
177 |
-
def __init__(
|
178 |
-
self,
|
179 |
-
app: gr.Blocks,
|
180 |
-
dataset: Dataset,
|
181 |
-
model_options: dict,
|
182 |
-
defaults: dict,
|
183 |
-
):
|
184 |
"""Initialize the Tossup interface."""
|
185 |
logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
|
186 |
self.ds = dataset
|
@@ -196,6 +191,7 @@ class TossupInterface:
|
|
196 |
self.pipeline_selector = commons.get_pipeline_selector([])
|
197 |
self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
|
198 |
self.pipeline_interface = TossupPipelineInterface(
|
|
|
199 |
workflow,
|
200 |
simple=simple,
|
201 |
model_options=list(self.model_options.keys()),
|
@@ -255,9 +251,10 @@ class TossupInterface:
|
|
255 |
|
256 |
self._setup_event_listeners()
|
257 |
|
258 |
-
def validate_workflow(self,
|
259 |
"""Validate the workflow."""
|
260 |
try:
|
|
|
261 |
validate_workflow(pipeline_state.workflow)
|
262 |
except Exception as e:
|
263 |
raise gr.Error(f"Error validating workflow: {str(e)}")
|
@@ -293,20 +290,25 @@ class TossupInterface:
|
|
293 |
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
|
294 |
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
|
295 |
|
296 |
-
def load_pipeline(
|
|
|
|
|
297 |
try:
|
298 |
-
|
299 |
-
if
|
300 |
-
|
301 |
-
|
|
|
|
|
302 |
except Exception as e:
|
|
|
303 |
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
|
304 |
-
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.update(visible=True, value=error_msg)
|
305 |
|
306 |
def single_run(
|
307 |
self,
|
308 |
question_id: int,
|
309 |
-
|
310 |
early_stop: bool = True,
|
311 |
) -> tuple[str, Any, Any]:
|
312 |
"""Run the agent in tossup mode with a system prompt."""
|
@@ -316,6 +318,7 @@ class TossupInterface:
|
|
316 |
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
317 |
return "Invalid question ID or dataset not loaded", None, None
|
318 |
example = self.ds[question_id]
|
|
|
319 |
outputs = self.get_model_outputs(example, pipeline_state, early_stop)
|
320 |
|
321 |
# Process results and prepare visualization data
|
@@ -343,13 +346,13 @@ class TossupInterface:
|
|
343 |
gr.update(visible=True, value=error_msg),
|
344 |
)
|
345 |
|
346 |
-
def evaluate(self,
|
347 |
"""Evaluate the tossup questions."""
|
348 |
try:
|
349 |
# Validate inputs
|
350 |
if not self.ds or not self.ds.num_rows:
|
351 |
return "No dataset loaded", None, None
|
352 |
-
|
353 |
buzz_counts = 0
|
354 |
correct_buzzes = 0
|
355 |
token_positions = []
|
@@ -389,9 +392,14 @@ class TossupInterface:
|
|
389 |
)
|
390 |
|
391 |
def submit_model(
|
392 |
-
self,
|
|
|
|
|
|
|
|
|
393 |
):
|
394 |
"""Submit the model output."""
|
|
|
395 |
return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
|
396 |
|
397 |
def _setup_event_listeners(self):
|
@@ -408,20 +416,19 @@ class TossupInterface:
|
|
408 |
outputs=[self.pipeline_selector],
|
409 |
)
|
410 |
|
411 |
-
|
|
|
412 |
self.load_btn.click(
|
413 |
fn=self.load_pipeline,
|
414 |
-
inputs=[self.pipeline_selector],
|
415 |
-
outputs=[self.pipeline_selector,
|
416 |
-
)
|
417 |
-
self.pipeline_interface.add_triggers_for_pipeline_export(
|
418 |
-
[self.new_loaded_pipeline_state.change], self.new_loaded_pipeline_state
|
419 |
)
|
|
|
420 |
|
421 |
self.run_btn.click(
|
422 |
self.pipeline_interface.validate_workflow,
|
423 |
inputs=[self.pipeline_interface.pipeline_state],
|
424 |
-
outputs=[
|
425 |
).success(
|
426 |
self.single_run,
|
427 |
inputs=[
|
@@ -454,9 +461,3 @@ class TossupInterface:
|
|
454 |
],
|
455 |
outputs=[self.submit_status],
|
456 |
)
|
457 |
-
|
458 |
-
self.hidden_input.change(
|
459 |
-
fn=update_tossup_plot,
|
460 |
-
inputs=[self.hidden_input, self.output_state],
|
461 |
-
outputs=[self.confidence_plot],
|
462 |
-
)
|
|
|
11 |
from components import commons
|
12 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
13 |
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
|
14 |
+
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
15 |
from display.formatting import styled_error
|
16 |
from submission import submit
|
17 |
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
|
|
|
175 |
class TossupInterface:
|
176 |
"""Gradio interface for the Tossup mode."""
|
177 |
|
178 |
+
def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
"""Initialize the Tossup interface."""
|
180 |
logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
|
181 |
self.ds = dataset
|
|
|
191 |
self.pipeline_selector = commons.get_pipeline_selector([])
|
192 |
self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
|
193 |
self.pipeline_interface = TossupPipelineInterface(
|
194 |
+
self.app,
|
195 |
workflow,
|
196 |
simple=simple,
|
197 |
model_options=list(self.model_options.keys()),
|
|
|
251 |
|
252 |
self._setup_event_listeners()
|
253 |
|
254 |
+
def validate_workflow(self, state_dict: TossupPipelineStateDict):
|
255 |
"""Validate the workflow."""
|
256 |
try:
|
257 |
+
pipeline_state = TossupPipelineState(**state_dict)
|
258 |
validate_workflow(pipeline_state.workflow)
|
259 |
except Exception as e:
|
260 |
raise gr.Error(f"Error validating workflow: {str(e)}")
|
|
|
290 |
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
|
291 |
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
|
292 |
|
293 |
+
def load_pipeline(
|
294 |
+
self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
|
295 |
+
) -> tuple[str, PipelineStateDict, bool, dict]:
|
296 |
try:
|
297 |
+
workflow = populate.load_workflow("tossup", model_name, profile)
|
298 |
+
if workflow is None:
|
299 |
+
logger.warning(f"Could not load workflow for {model_name}")
|
300 |
+
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
|
301 |
+
pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
|
302 |
+
return UNSELECTED_PIPELINE_NAME, pipeline_state_dict, not pipeline_change, gr.update(visible=True)
|
303 |
except Exception as e:
|
304 |
+
logger.exception(e)
|
305 |
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
|
306 |
+
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
|
307 |
|
308 |
def single_run(
|
309 |
self,
|
310 |
question_id: int,
|
311 |
+
state_dict: TossupPipelineStateDict,
|
312 |
early_stop: bool = True,
|
313 |
) -> tuple[str, Any, Any]:
|
314 |
"""Run the agent in tossup mode with a system prompt."""
|
|
|
318 |
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
319 |
return "Invalid question ID or dataset not loaded", None, None
|
320 |
example = self.ds[question_id]
|
321 |
+
pipeline_state = TossupPipelineState(**state_dict)
|
322 |
outputs = self.get_model_outputs(example, pipeline_state, early_stop)
|
323 |
|
324 |
# Process results and prepare visualization data
|
|
|
346 |
gr.update(visible=True, value=error_msg),
|
347 |
)
|
348 |
|
349 |
+
def evaluate(self, state_dict: TossupPipelineStateDict, progress: gr.Progress = gr.Progress()):
|
350 |
"""Evaluate the tossup questions."""
|
351 |
try:
|
352 |
# Validate inputs
|
353 |
if not self.ds or not self.ds.num_rows:
|
354 |
return "No dataset loaded", None, None
|
355 |
+
pipeline_state = TossupPipelineState(**state_dict)
|
356 |
buzz_counts = 0
|
357 |
correct_buzzes = 0
|
358 |
token_positions = []
|
|
|
392 |
)
|
393 |
|
394 |
def submit_model(
|
395 |
+
self,
|
396 |
+
model_name: str,
|
397 |
+
description: str,
|
398 |
+
state_dict: TossupPipelineStateDict,
|
399 |
+
profile: gr.OAuthProfile = None,
|
400 |
):
|
401 |
"""Submit the model output."""
|
402 |
+
pipeline_state = TossupPipelineState(**state_dict)
|
403 |
return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
|
404 |
|
405 |
def _setup_event_listeners(self):
|
|
|
416 |
outputs=[self.pipeline_selector],
|
417 |
)
|
418 |
|
419 |
+
pipeline_state = self.pipeline_interface.pipeline_state
|
420 |
+
pipeline_change = self.pipeline_interface.pipeline_change
|
421 |
self.load_btn.click(
|
422 |
fn=self.load_pipeline,
|
423 |
+
inputs=[self.pipeline_selector, pipeline_change],
|
424 |
+
outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.error_display],
|
|
|
|
|
|
|
425 |
)
|
426 |
+
self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
|
427 |
|
428 |
self.run_btn.click(
|
429 |
self.pipeline_interface.validate_workflow,
|
430 |
inputs=[self.pipeline_interface.pipeline_state],
|
431 |
+
outputs=[],
|
432 |
).success(
|
433 |
self.single_run,
|
434 |
inputs=[
|
|
|
461 |
],
|
462 |
outputs=[self.submit_status],
|
463 |
)
|
|
|
|
|
|
|
|
|
|
|
|
src/components/structs.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Literal
|
2 |
+
|
3 |
+
from pydantic import BaseModel, Field, model_validator
|
4 |
+
|
5 |
+
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
6 |
+
|
7 |
+
|
8 |
+
def make_step_id(step_number: int):
|
9 |
+
"""Make a step id from a step name."""
|
10 |
+
if step_number < 26:
|
11 |
+
return chr(ord("A") + step_number)
|
12 |
+
else:
|
13 |
+
# For more than 26 steps, use AA, AB, AC, etc.
|
14 |
+
first_char = chr(ord("A") + (step_number // 26) - 1)
|
15 |
+
second_char = chr(ord("A") + (step_number % 26))
|
16 |
+
return f"{first_char}{second_char}"
|
17 |
+
|
18 |
+
|
19 |
+
def make_step_number(step_id: str):
|
20 |
+
"""Make a step number from a step id."""
|
21 |
+
if len(step_id) == 1:
|
22 |
+
return ord(step_id) - ord("A")
|
23 |
+
else:
|
24 |
+
return (ord(step_id[0]) - ord("A")) * 26 + (ord(step_id[1]) - ord("A")) + 1
|
25 |
+
|
26 |
+
|
27 |
+
class ModelStepUIState(BaseModel):
|
28 |
+
"""Represents the UI state for a model step component."""
|
29 |
+
|
30 |
+
expanded: bool = True
|
31 |
+
active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"
|
32 |
+
|
33 |
+
class Config:
|
34 |
+
frozen = True
|
35 |
+
|
36 |
+
def update(self, key: str, value: Any) -> "ModelStepUIState":
|
37 |
+
"""Update the UI state."""
|
38 |
+
return self.model_copy(update={key: value})
|
39 |
+
|
40 |
+
|
41 |
+
class PipelineUIState(BaseModel):
|
42 |
+
"""Represents the UI state for a pipeline component."""
|
43 |
+
|
44 |
+
step_ids: list[str] = Field(default_factory=list)
|
45 |
+
steps: dict[str, ModelStepUIState] = Field(default_factory=dict)
|
46 |
+
|
47 |
+
def model_post_init(self, __context: Any) -> None:
|
48 |
+
if not self.steps and self.step_ids:
|
49 |
+
self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
|
50 |
+
return super().model_post_init(__context)
|
51 |
+
|
52 |
+
def get_step_position(self, step_id: str):
|
53 |
+
"""Get the position of a step in the pipeline."""
|
54 |
+
return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def n_steps(self) -> int:
|
58 |
+
"""Get the number of steps in the pipeline."""
|
59 |
+
return len(self.step_ids)
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_workflow(cls, workflow: Workflow):
|
63 |
+
"""Create a pipeline UI state from a workflow."""
|
64 |
+
return PipelineUIState(
|
65 |
+
step_ids=list(workflow.steps.keys()),
|
66 |
+
steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
|
67 |
+
)
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def from_pipeline_state(cls, pipeline_state: "PipelineState"):
|
71 |
+
"""Create a pipeline UI state from a pipeline state."""
|
72 |
+
return cls.from_workflow(pipeline_state.workflow)
|
73 |
+
|
74 |
+
# Update methods
|
75 |
+
|
76 |
+
def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
|
77 |
+
"""Insert a step into the pipeline at the given position."""
|
78 |
+
if position == -1:
|
79 |
+
position = len(self.step_ids)
|
80 |
+
self.step_ids.insert(position, step_id)
|
81 |
+
steps = self.steps | {step_id: ModelStepUIState()}
|
82 |
+
return self.model_copy(update={"step_ids": self.step_ids, "steps": steps})
|
83 |
+
|
84 |
+
def remove_step(self, step_id: str) -> "PipelineUIState":
|
85 |
+
"""Remove a step from the pipeline."""
|
86 |
+
self.step_ids.remove(step_id)
|
87 |
+
self.steps.pop(step_id)
|
88 |
+
return self.model_copy(update={"step_ids": self.step_ids, "steps": self.steps})
|
89 |
+
|
90 |
+
def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
|
91 |
+
"""Update a step in the pipeline."""
|
92 |
+
if step_id not in self.steps:
|
93 |
+
raise ValueError(f"Step {step_id} not found in pipeline")
|
94 |
+
return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
|
95 |
+
|
96 |
+
|
97 |
+
class PipelineState(BaseModel):
|
98 |
+
"""Represents the state for a pipeline component."""
|
99 |
+
|
100 |
+
workflow: Workflow
|
101 |
+
ui_state: PipelineUIState
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def from_workflow(cls, workflow: Workflow):
|
105 |
+
"""Create a pipeline state from a workflow."""
|
106 |
+
return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))
|
107 |
+
|
108 |
+
def update_workflow(self, workflow: Workflow) -> "PipelineState":
|
109 |
+
return self.model_copy(update={"workflow": workflow})
|
110 |
+
|
111 |
+
def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
|
112 |
+
if step.id in self.workflow.steps:
|
113 |
+
raise ValueError(f"Step {step.id} already exists in pipeline")
|
114 |
+
|
115 |
+
# Validate position
|
116 |
+
if position != -1 and (position < 0 or position > self.n_steps):
|
117 |
+
raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
|
118 |
+
|
119 |
+
# Create a new workflow with updated steps
|
120 |
+
workflow = self.workflow.add_step(step)
|
121 |
+
|
122 |
+
self.ui_state = self.ui_state.insert_step(step.id, position)
|
123 |
+
|
124 |
+
# Return a new PipelineState with the updated workflow
|
125 |
+
return self.model_copy(update={"workflow": workflow, "ui_state": self.ui_state})
|
126 |
+
|
127 |
+
def remove_step(self, position: int) -> "PipelineState":
|
128 |
+
step_id = self.ui_state.step_ids.pop(position)
|
129 |
+
|
130 |
+
workflow = self.workflow.remove_step(step_id)
|
131 |
+
self.ui_state = self.ui_state.remove_step(step_id)
|
132 |
+
|
133 |
+
# Return a new PipelineState with the updated workflow
|
134 |
+
updated_outputs = self.get_output_variables_updates(workflow)
|
135 |
+
return self.model_copy(update={"workflow": workflow, "outputs": updated_outputs})
|
136 |
+
|
137 |
+
def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
|
138 |
+
"""Update a step in the pipeline."""
|
139 |
+
if step.id not in self.workflow.steps:
|
140 |
+
raise ValueError(f"Step {step.id} not found in pipeline")
|
141 |
+
steps = self.workflow.steps | {step.id: step}
|
142 |
+
workflow = self.workflow.model_copy(update={"steps": steps})
|
143 |
+
update = {"workflow": workflow, "outputs": self.get_output_variables_updates(workflow)}
|
144 |
+
if ui_state is not None:
|
145 |
+
update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
|
146 |
+
return self.model_copy(update=update)
|
147 |
+
|
148 |
+
def get_output_variables_updates(self, new_workflow: Workflow) -> dict[str, str | None]:
|
149 |
+
available_variables = set(self.available_variables)
|
150 |
+
updated_outputs = new_workflow.outputs.copy()
|
151 |
+
for output_field in updated_outputs:
|
152 |
+
if updated_outputs[output_field] not in available_variables:
|
153 |
+
updated_outputs[output_field] = None
|
154 |
+
return updated_outputs
|
155 |
+
|
156 |
+
def update_output_variables_mapping(self) -> "PipelineState":
|
157 |
+
updated_outputs = self.get_output_variables_updates(self.workflow)
|
158 |
+
|
159 |
+
# Create a new workflow with updated outputs
|
160 |
+
workflow = self.workflow.model_copy(update={"outputs": updated_outputs})
|
161 |
+
|
162 |
+
# Return a new PipelineState with the updated workflow
|
163 |
+
return self.model_copy(update={"workflow": workflow})
|
164 |
+
|
165 |
+
def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
|
166 |
+
"""Get all variables from all steps."""
|
167 |
+
available_variables = self.available_variables
|
168 |
+
if model_step_id is None:
|
169 |
+
return available_variables
|
170 |
+
prefix = f"{model_step_id}."
|
171 |
+
return [var for var in available_variables if not var.startswith(prefix)]
|
172 |
+
|
173 |
+
@property
|
174 |
+
def available_variables(self) -> list[str]:
|
175 |
+
return self.workflow.get_available_variables()
|
176 |
+
|
177 |
+
@property
|
178 |
+
def n_steps(self) -> int:
|
179 |
+
return len(self.workflow.steps)
|
180 |
+
|
181 |
+
def get_new_step_id(self) -> str:
|
182 |
+
"""Get a step ID for a new step."""
|
183 |
+
if not self.workflow.steps:
|
184 |
+
return "A"
|
185 |
+
else:
|
186 |
+
last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
|
187 |
+
return make_step_id(last_step_number + 1)
|
188 |
+
|
189 |
+
|
190 |
+
class TossupPipelineState(PipelineState):
|
191 |
+
workflow: TossupWorkflow
|
src/components/typed_dicts.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
2 |
+
|
3 |
+
|
4 |
+
# TypedDicts for workflows/structs.py
|
5 |
+
class InputFieldDict(TypedDict):
|
6 |
+
name: str
|
7 |
+
description: str
|
8 |
+
variable: str
|
9 |
+
func: Optional[str]
|
10 |
+
|
11 |
+
|
12 |
+
class OutputFieldDict(TypedDict):
|
13 |
+
name: str
|
14 |
+
type: Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
|
15 |
+
description: str
|
16 |
+
func: Optional[str]
|
17 |
+
|
18 |
+
|
19 |
+
class ModelStepDict(TypedDict):
|
20 |
+
id: str
|
21 |
+
name: str
|
22 |
+
model: str
|
23 |
+
provider: str
|
24 |
+
call_type: Literal["llm", "search", "python_func"]
|
25 |
+
temperature: Optional[float]
|
26 |
+
system_prompt: str
|
27 |
+
input_fields: List[InputFieldDict]
|
28 |
+
output_fields: List[OutputFieldDict]
|
29 |
+
|
30 |
+
|
31 |
+
class WorkflowDict(TypedDict):
|
32 |
+
inputs: List[str]
|
33 |
+
outputs: Dict[str, Optional[str]]
|
34 |
+
steps: Dict[str, ModelStepDict]
|
35 |
+
|
36 |
+
|
37 |
+
class BuzzerDict(TypedDict):
|
38 |
+
method: Literal["AND", "OR"]
|
39 |
+
confidence_threshold: float
|
40 |
+
prob_threshold: Optional[float]
|
41 |
+
|
42 |
+
|
43 |
+
class TossupWorkflowDict(WorkflowDict):
|
44 |
+
buzzer: BuzzerDict
|
45 |
+
|
46 |
+
|
47 |
+
# TypedDicts for components/model_pipeline/state_manager.py
|
48 |
+
class ModelStepUIStateDict(TypedDict):
|
49 |
+
expanded: bool
|
50 |
+
active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"]
|
51 |
+
|
52 |
+
|
53 |
+
class PipelineUIStateDict(TypedDict):
|
54 |
+
step_ids: List[str]
|
55 |
+
steps: Dict[str, ModelStepUIStateDict]
|
56 |
+
|
57 |
+
|
58 |
+
class PipelineStateDict(TypedDict):
|
59 |
+
workflow: WorkflowDict
|
60 |
+
ui_state: PipelineUIStateDict
|
61 |
+
|
62 |
+
|
63 |
+
class TossupPipelineStateDict(PipelineStateDict):
|
64 |
+
workflow: TossupWorkflowDict
|
src/display/guide.py
CHANGED
@@ -2,7 +2,9 @@
|
|
2 |
|
3 |
GUIDE_MARKDOWN = """
|
4 |
# 🎯 Quizbowl Bot Guide
|
|
|
5 |
|
|
|
6 |
## Quick Start
|
7 |
1. Choose between Tossup or Bonus mode
|
8 |
2. Design your pipeline
|
@@ -35,7 +37,9 @@ GUIDE_MARKDOWN = """
|
|
35 |
- `answer`: Your predicted answer
|
36 |
- `confidence`: Score between 0-1
|
37 |
- `explanation`: Brief justification for human collaboration
|
|
|
38 |
|
|
|
39 |
## Building Your First Pipeline
|
40 |
|
41 |
### 1. Simple Pipeline (Recommended for First Submission)
|
|
|
2 |
|
3 |
GUIDE_MARKDOWN = """
|
4 |
# 🎯 Quizbowl Bot Guide
|
5 |
+
"""
|
6 |
|
7 |
+
QUICKSTART_MARKDOWN = """
|
8 |
## Quick Start
|
9 |
1. Choose between Tossup or Bonus mode
|
10 |
2. Design your pipeline
|
|
|
37 |
- `answer`: Your predicted answer
|
38 |
- `confidence`: Score between 0-1
|
39 |
- `explanation`: Brief justification for human collaboration
|
40 |
+
"""
|
41 |
|
42 |
+
BUILDING_MARKDOWN = """
|
43 |
## Building Your First Pipeline
|
44 |
|
45 |
### 1. Simple Pipeline (Recommended for First Submission)
|
src/submission/structs.py
CHANGED
@@ -37,6 +37,7 @@ class Submission(BaseModel):
|
|
37 |
description: str = Field(description="Detailed description of what the submission does")
|
38 |
competition_type: CompetitionType = Field(description="Type of competition (tossup or bonus)")
|
39 |
submission_type: SubmissionType = Field(description="Format of the submission (python file or workflow)")
|
|
|
40 |
workflow: Optional[Workflow] = Field(default=None, description="Optional workflow definition stored as JSON dict")
|
41 |
code: Optional[str] = Field(default=None, description="Optional code content for python file submissions")
|
42 |
status: SubmissionStatus = Field(description="Current status of the submission")
|
|
|
37 |
description: str = Field(description="Detailed description of what the submission does")
|
38 |
competition_type: CompetitionType = Field(description="Type of competition (tossup or bonus)")
|
39 |
submission_type: SubmissionType = Field(description="Format of the submission (python file or workflow)")
|
40 |
+
# TODO: Make workflow as json / yaml string instead of Workflow object
|
41 |
workflow: Optional[Workflow] = Field(default=None, description="Optional workflow definition stored as JSON dict")
|
42 |
code: Optional[str] = Field(default=None, description="Optional code content for python file submissions")
|
43 |
status: SubmissionStatus = Field(description="Current status of the submission")
|
src/workflows/structs.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
# %%
|
|
|
2 |
from enum import Enum
|
3 |
from typing import Any, Literal, Optional
|
4 |
|
@@ -48,6 +49,9 @@ class InputField(BaseModel):
|
|
48 |
# function to call on the input before passing it to the model
|
49 |
func: str | None = None
|
50 |
|
|
|
|
|
|
|
51 |
|
52 |
class OutputField(BaseModel):
|
53 |
"""
|
@@ -70,6 +74,9 @@ class OutputField(BaseModel):
|
|
70 |
# function to call on the output string from the model
|
71 |
func: str | None = None
|
72 |
|
|
|
|
|
|
|
73 |
|
74 |
class CallType(str, Enum):
|
75 |
LLM = "llm"
|
@@ -120,6 +127,7 @@ class ModelStep(BaseModel):
|
|
120 |
return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]
|
121 |
|
122 |
def update(self, update: dict[str, Any]) -> "ModelStep":
|
|
|
123 |
return self.model_copy(update=update)
|
124 |
|
125 |
def update_property(self, field: str, value: Any) -> "ModelStep":
|
@@ -157,11 +165,16 @@ class ModelStep(BaseModel):
|
|
157 |
Returns:
|
158 |
A new ModelStep with the updated fields.
|
159 |
"""
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
|
167 |
"""
|
@@ -174,10 +187,10 @@ class ModelStep(BaseModel):
|
|
174 |
Returns:
|
175 |
A new ModelStep with the updated fields.
|
176 |
"""
|
177 |
-
|
178 |
-
fields =
|
179 |
fields.pop(index)
|
180 |
-
return
|
181 |
|
182 |
|
183 |
class Workflow(BaseModel):
|
@@ -242,6 +255,22 @@ class Workflow(BaseModel):
|
|
242 |
variables.update(self.get_step_variables(step.id))
|
243 |
return list(variables)
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
class BuzzerMethod(str, Enum):
|
247 |
AND = "AND"
|
@@ -257,6 +286,7 @@ class Buzzer(BaseModel):
|
|
257 |
|
258 |
class Config:
|
259 |
use_enum_values = True
|
|
|
260 |
|
261 |
def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
|
262 |
"""Run the buzzer logic."""
|
@@ -285,3 +315,7 @@ class TossupWorkflow(Workflow):
|
|
285 |
"""Workflow specialized for tossup questions with buzzing capability."""
|
286 |
|
287 |
buzzer: Buzzer
|
|
|
|
|
|
|
|
|
|
1 |
# %%
|
2 |
+
from copy import deepcopy
|
3 |
from enum import Enum
|
4 |
from typing import Any, Literal, Optional
|
5 |
|
|
|
49 |
# function to call on the input before passing it to the model
|
50 |
func: str | None = None
|
51 |
|
52 |
+
class Config:
|
53 |
+
frozen = True
|
54 |
+
|
55 |
|
56 |
class OutputField(BaseModel):
|
57 |
"""
|
|
|
74 |
# function to call on the output string from the model
|
75 |
func: str | None = None
|
76 |
|
77 |
+
class Config:
|
78 |
+
frozen = True
|
79 |
+
|
80 |
|
81 |
class CallType(str, Enum):
|
82 |
LLM = "llm"
|
|
|
127 |
return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]
|
128 |
|
129 |
def update(self, update: dict[str, Any]) -> "ModelStep":
|
130 |
+
"""Returns a new copy with the updated properties."""
|
131 |
return self.model_copy(update=update)
|
132 |
|
133 |
def update_property(self, field: str, value: Any) -> "ModelStep":
|
|
|
165 |
Returns:
|
166 |
A new ModelStep with the updated fields.
|
167 |
"""
|
168 |
+
if field_type == "input":
|
169 |
+
fields = deepcopy(self.input_fields)
|
170 |
+
new_field = ModelStep.create_new_field(field_type, input_var)
|
171 |
+
fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
|
172 |
+
return self.model_copy(update={"input_fields": fields})
|
173 |
+
else:
|
174 |
+
fields = deepcopy(self.output_fields)
|
175 |
+
new_field = ModelStep.create_new_field(field_type)
|
176 |
+
fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
|
177 |
+
return self.model_copy(update={"output_fields": fields})
|
178 |
|
179 |
def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
|
180 |
"""
|
|
|
187 |
Returns:
|
188 |
A new ModelStep with the updated fields.
|
189 |
"""
|
190 |
+
fields = self.input_fields if field_type == "input" else self.output_fields
|
191 |
+
fields = deepcopy(fields)
|
192 |
fields.pop(index)
|
193 |
+
return self.model_copy(update={"input_fields": fields} if field_type == "input" else {"output_fields": fields})
|
194 |
|
195 |
|
196 |
class Workflow(BaseModel):
|
|
|
255 |
variables.update(self.get_step_variables(step.id))
|
256 |
return list(variables)
|
257 |
|
258 |
+
def get_model_selections(self) -> dict[str, str]:
|
259 |
+
"""Get all model selections for all steps."""
|
260 |
+
return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}
|
261 |
+
|
262 |
+
# Step update method
|
263 |
+
|
264 |
+
def add_step(self, step: ModelStep) -> "Workflow":
|
265 |
+
"""Add a step to the workflow."""
|
266 |
+
steps = self.steps | {step.id: step}
|
267 |
+
return self.model_copy(update={"steps": steps})
|
268 |
+
|
269 |
+
def remove_step(self, step_id: str) -> "Workflow":
|
270 |
+
"""Remove a step from the workflow."""
|
271 |
+
self.steps.pop(step_id)
|
272 |
+
return self.model_copy(update={"steps": self.steps})
|
273 |
+
|
274 |
|
275 |
class BuzzerMethod(str, Enum):
|
276 |
AND = "AND"
|
|
|
286 |
|
287 |
class Config:
|
288 |
use_enum_values = True
|
289 |
+
frozen = True
|
290 |
|
291 |
def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
|
292 |
"""Run the buzzer logic."""
|
|
|
315 |
"""Workflow specialized for tossup questions with buzzing capability."""
|
316 |
|
317 |
buzzer: Buzzer
|
318 |
+
|
319 |
+
def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
|
320 |
+
"""Update the buzzer."""
|
321 |
+
return self.model_copy(update={"buzzer": buzzer})
|