Maharshi Gor
commited on
Commit
·
5f3e7d5
1
Parent(s):
5d637a7
Made workflows a submodule
Browse files- .gitignore +4 -4
- .gitmodules +3 -0
- app.py +2 -2
- check_repos.py +0 -26
- shared/__init__.py +0 -0
- shared/workflows +1 -0
- src/components/model_pipeline/model_pipeline.py +2 -2
- src/components/model_pipeline/state_manager.py +6 -6
- src/components/model_pipeline/tossup_pipeline.py +3 -7
- src/components/model_step/model_step.py +1 -1
- src/components/model_step/state_manager.py +1 -1
- src/components/quizbowl/bonus.py +2 -2
- src/components/quizbowl/tossup.py +2 -2
- src/components/quizbowl/validation.py +2 -2
- src/components/structs.py +1 -1
- src/components/typed_dicts.py +1 -1
- src/populate.py +3 -5
- src/submission/_submit.py +119 -0
- src/submission/check_validity.py +99 -0
- src/submission/structs.py +1 -1
- src/submission/submit.py +2 -2
- src/workflows/README.md +0 -129
- src/workflows/configs.py +0 -56
- src/workflows/errors.py +0 -63
- src/workflows/executors.py +0 -673
- src/workflows/factory.py +0 -176
- src/workflows/llmcache.py +0 -488
- src/workflows/llms.py +0 -285
- src/workflows/qb_agents.py +0 -232
- src/workflows/structs.py +0 -370
- src/workflows/utils.py +0 -195
- src/workflows/validators.py +0 -615
.gitignore
CHANGED
@@ -16,8 +16,8 @@ __pycache__/
|
|
16 |
*ipynb
|
17 |
.vscode/
|
18 |
|
19 |
-
eval
|
20 |
-
eval-results/
|
21 |
-
eval-queue-bk/
|
22 |
-
eval-results-bk/
|
23 |
logs/
|
|
|
|
|
|
|
|
16 |
*ipynb
|
17 |
.vscode/
|
18 |
|
19 |
+
eval-*/
|
|
|
|
|
|
|
20 |
logs/
|
21 |
+
data/
|
22 |
+
outputs/
|
23 |
+
hf_cache/
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "shared/workflows"]
|
2 |
+
path = shared/workflows
|
3 |
+
url = https://github.com/qanta-challenge/ai_workflows
|
app.py
CHANGED
@@ -30,8 +30,8 @@ from envs import (
|
|
30 |
RESULTS_REPO,
|
31 |
SERVER_REFRESH_INTERVAL,
|
32 |
)
|
33 |
-
from workflows import factory
|
34 |
-
from workflows.configs import AVAILABLE_MODELS
|
35 |
|
36 |
|
37 |
def restart_space():
|
|
|
30 |
RESULTS_REPO,
|
31 |
SERVER_REFRESH_INTERVAL,
|
32 |
)
|
33 |
+
from shared.workflows import factory
|
34 |
+
from shared.workflows.configs import AVAILABLE_MODELS
|
35 |
|
36 |
|
37 |
def restart_space():
|
check_repos.py
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
from huggingface_hub import HfApi
|
2 |
-
|
3 |
-
from src.envs import LLM_CACHE_REPO, QUEUE_REPO, RESULTS_REPO, TOKEN
|
4 |
-
|
5 |
-
|
6 |
-
def check_and_create_dataset_repo(repo_id: str):
|
7 |
-
api = HfApi(token=TOKEN)
|
8 |
-
try:
|
9 |
-
api.repo_info(repo_id=repo_id, repo_type="dataset")
|
10 |
-
print(f"{repo_id} exists")
|
11 |
-
except Exception:
|
12 |
-
print(f"Creating {repo_id}")
|
13 |
-
api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
|
14 |
-
|
15 |
-
|
16 |
-
def check_and_create_repos():
|
17 |
-
print("1. QUEUE Repository")
|
18 |
-
check_and_create_dataset_repo(QUEUE_REPO)
|
19 |
-
print("2. RESULTS Repository")
|
20 |
-
check_and_create_dataset_repo(RESULTS_REPO)
|
21 |
-
print("3. LLM Cache Repository")
|
22 |
-
check_and_create_dataset_repo(LLM_CACHE_REPO)
|
23 |
-
|
24 |
-
|
25 |
-
if __name__ == "__main__":
|
26 |
-
check_and_create_repos()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/__init__.py
ADDED
File without changes
|
shared/workflows
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 9d8bfae31f4db8b25165c950742d13e6c4e80de8
|
src/components/model_pipeline/model_pipeline.py
CHANGED
@@ -12,8 +12,8 @@ from components.model_pipeline.state_manager import (
|
|
12 |
from components.model_step.model_step import ModelStepComponent
|
13 |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState
|
14 |
from components.utils import make_state
|
15 |
-
from workflows.structs import ModelStep, Workflow
|
16 |
-
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
17 |
|
18 |
from .state_manager import get_output_panel_state
|
19 |
|
|
|
12 |
from components.model_step.model_step import ModelStepComponent
|
13 |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState
|
14 |
from components.utils import make_state
|
15 |
+
from shared.workflows.structs import ModelStep, Workflow
|
16 |
+
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator
|
17 |
|
18 |
from .state_manager import get_output_panel_state
|
19 |
|
src/components/model_pipeline/state_manager.py
CHANGED
@@ -13,8 +13,8 @@ from components import typed_dicts as td
|
|
13 |
from components import utils
|
14 |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
|
15 |
from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
|
16 |
-
from workflows.factory import create_new_llm_step
|
17 |
-
from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
|
18 |
|
19 |
|
20 |
def get_output_panel_state(workflow: Workflow) -> dict:
|
@@ -113,9 +113,7 @@ class PipelineStateManager:
|
|
113 |
pipeline_change = not pipeline_change
|
114 |
return new_state_dict, pipeline_change
|
115 |
|
116 |
-
def move_down(
|
117 |
-
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
|
118 |
-
) -> td.PipelineStateDict:
|
119 |
"""Move a step down in the pipeline."""
|
120 |
new_state_dict, change = self._move_step(state_dict, position, "down")
|
121 |
if change:
|
@@ -189,7 +187,9 @@ class PipelineStateManager:
|
|
189 |
help_text = f"Refer to the <a href='{repo_files_url}/pipeline-schema.md' target='_blank'>documentation</a> for the correct pipeline schema."
|
190 |
else:
|
191 |
error_type = "Unexpected Error"
|
192 |
-
help_text =
|
|
|
|
|
193 |
|
194 |
return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)
|
195 |
|
|
|
13 |
from components import utils
|
14 |
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
|
15 |
from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
|
16 |
+
from shared.workflows.factory import create_new_llm_step
|
17 |
+
from shared.workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
|
18 |
|
19 |
|
20 |
def get_output_panel_state(workflow: Workflow) -> dict:
|
|
|
113 |
pipeline_change = not pipeline_change
|
114 |
return new_state_dict, pipeline_change
|
115 |
|
116 |
+
def move_down(self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int) -> td.PipelineStateDict:
|
|
|
|
|
117 |
"""Move a step down in the pipeline."""
|
118 |
new_state_dict, change = self._move_step(state_dict, position, "down")
|
119 |
if change:
|
|
|
187 |
help_text = f"Refer to the <a href='{repo_files_url}/pipeline-schema.md' target='_blank'>documentation</a> for the correct pipeline schema."
|
188 |
else:
|
189 |
error_type = "Unexpected Error"
|
190 |
+
help_text = (
|
191 |
+
f"Please report this issue to us at <a href='{DOCS_REPO_URL}/issues' target='_blank'>GitHub Issues</a>."
|
192 |
+
)
|
193 |
|
194 |
return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)
|
195 |
|
src/components/model_pipeline/tossup_pipeline.py
CHANGED
@@ -6,15 +6,13 @@ from components import commons
|
|
6 |
from components.structs import PipelineUIState, TossupPipelineState
|
7 |
from components.typed_dicts import TossupPipelineStateDict
|
8 |
from display.formatting import tiny_styled_warning
|
9 |
-
from workflows.structs import Buzzer, TossupWorkflow
|
10 |
|
11 |
from .model_pipeline import PipelineInterface
|
12 |
from .state_manager import BasePipelineValidator, TossupPipelineStateManager
|
13 |
|
14 |
|
15 |
-
def toggleable_slider(
|
16 |
-
value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1
|
17 |
-
):
|
18 |
with gr.Column(elem_classes="toggleable", min_width=min_width, scale=scale):
|
19 |
show_label = label is not None
|
20 |
checkbox = gr.Checkbox(label=label, value=toggle_value, container=False, info=info, show_label=show_label)
|
@@ -90,9 +88,7 @@ class TossupPipelineInterface(PipelineInterface):
|
|
90 |
),
|
91 |
)
|
92 |
|
93 |
-
def _render_buzzer_panel(
|
94 |
-
self, buzzer: Buzzer, prob_slider_supported: bool, selected_model_name: str | None = None
|
95 |
-
):
|
96 |
with gr.Row(elem_classes="control-panel"):
|
97 |
self.confidence_slider = gr.Slider(
|
98 |
minimum=0.0,
|
|
|
6 |
from components.structs import PipelineUIState, TossupPipelineState
|
7 |
from components.typed_dicts import TossupPipelineStateDict
|
8 |
from display.formatting import tiny_styled_warning
|
9 |
+
from shared.workflows.structs import Buzzer, TossupWorkflow
|
10 |
|
11 |
from .model_pipeline import PipelineInterface
|
12 |
from .state_manager import BasePipelineValidator, TossupPipelineStateManager
|
13 |
|
14 |
|
15 |
+
def toggleable_slider(value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1):
|
|
|
|
|
16 |
with gr.Column(elem_classes="toggleable", min_width=min_width, scale=scale):
|
17 |
show_label = label is not None
|
18 |
checkbox = gr.Checkbox(label=label, value=toggle_value, container=False, info=info, show_label=show_label)
|
|
|
88 |
),
|
89 |
)
|
90 |
|
91 |
+
def _render_buzzer_panel(self, buzzer: Buzzer, prob_slider_supported: bool, selected_model_name: str | None = None):
|
|
|
|
|
92 |
with gr.Row(elem_classes="control-panel"):
|
93 |
self.confidence_slider = gr.Slider(
|
94 |
minimum=0.0,
|
src/components/model_step/model_step.py
CHANGED
@@ -7,8 +7,8 @@ from gradio.components import FormComponent
|
|
7 |
from app_configs import UNSELECTED_VAR_NAME
|
8 |
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
|
9 |
from components.typed_dicts import PipelineStateDict
|
|
|
10 |
from utils import get_full_model_name
|
11 |
-
from workflows.structs import ModelStep
|
12 |
|
13 |
from .state_manager import ModelStepStateManager
|
14 |
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
|
|
|
7 |
from app_configs import UNSELECTED_VAR_NAME
|
8 |
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
|
9 |
from components.typed_dicts import PipelineStateDict
|
10 |
+
from shared.workflows.structs import ModelStep
|
11 |
from utils import get_full_model_name
|
|
|
12 |
|
13 |
from .state_manager import ModelStepStateManager
|
14 |
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
|
src/components/model_step/state_manager.py
CHANGED
@@ -6,8 +6,8 @@ from loguru import logger
|
|
6 |
from app_configs import UNSELECTED_VAR_NAME
|
7 |
from components.model_pipeline.state_manager import ModelStepUIState
|
8 |
from components.utils import DIRECTIONS, move_item
|
|
|
9 |
from utils import get_model_and_provider
|
10 |
-
from workflows.structs import FieldType, ModelStep
|
11 |
|
12 |
|
13 |
class ModelStepStateManager:
|
|
|
6 |
from app_configs import UNSELECTED_VAR_NAME
|
7 |
from components.model_pipeline.state_manager import ModelStepUIState
|
8 |
from components.utils import DIRECTIONS, move_item
|
9 |
+
from shared.workflows.structs import FieldType, ModelStep
|
10 |
from utils import get_model_and_provider
|
|
|
11 |
|
12 |
|
13 |
class ModelStepStateManager:
|
src/components/quizbowl/bonus.py
CHANGED
@@ -12,9 +12,9 @@ from components import commons
|
|
12 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
|
13 |
from components.typed_dicts import PipelineStateDict
|
14 |
from display.formatting import styled_error
|
|
|
|
|
15 |
from submission import submit
|
16 |
-
from workflows import factory
|
17 |
-
from workflows.qb_agents import QuizBowlBonusAgent
|
18 |
|
19 |
from . import populate, validation
|
20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
|
|
12 |
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
|
13 |
from components.typed_dicts import PipelineStateDict
|
14 |
from display.formatting import styled_error
|
15 |
+
from shared.workflows import factory
|
16 |
+
from shared.workflows.qb_agents import QuizBowlBonusAgent
|
17 |
from submission import submit
|
|
|
|
|
18 |
|
19 |
from . import populate, validation
|
20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
src/components/quizbowl/tossup.py
CHANGED
@@ -12,9 +12,9 @@ from components import commons
|
|
12 |
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
|
13 |
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
|
14 |
from display.formatting import styled_error
|
|
|
|
|
15 |
from submission import submit
|
16 |
-
from workflows import factory
|
17 |
-
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
|
18 |
|
19 |
from . import populate, validation
|
20 |
from .plotting import (
|
|
|
12 |
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
|
13 |
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
|
14 |
from display.formatting import styled_error
|
15 |
+
from shared.workflows import factory
|
16 |
+
from shared.workflows.qb_agents import QuizBowlTossupAgent, TossupResult
|
17 |
from submission import submit
|
|
|
|
|
18 |
|
19 |
from . import populate, validation
|
20 |
from .plotting import (
|
src/components/quizbowl/validation.py
CHANGED
@@ -3,8 +3,8 @@ from typing import Literal
|
|
3 |
from app_configs import AVAILABLE_MODELS, CONFIGS
|
4 |
from components.structs import PipelineState, TossupPipelineState
|
5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
6 |
-
from workflows.structs import TossupWorkflow, Workflow
|
7 |
-
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
8 |
|
9 |
|
10 |
def validate_workflow(
|
|
|
3 |
from app_configs import AVAILABLE_MODELS, CONFIGS
|
4 |
from components.structs import PipelineState, TossupPipelineState
|
5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
6 |
+
from shared.workflows.structs import TossupWorkflow, Workflow
|
7 |
+
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator
|
8 |
|
9 |
|
10 |
def validate_workflow(
|
src/components/structs.py
CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Literal
|
|
2 |
|
3 |
from pydantic import BaseModel, Field, model_validator
|
4 |
|
5 |
-
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
6 |
|
7 |
|
8 |
def make_step_id(step_number: int):
|
|
|
2 |
|
3 |
from pydantic import BaseModel, Field, model_validator
|
4 |
|
5 |
+
from shared.workflows.structs import ModelStep, TossupWorkflow, Workflow
|
6 |
|
7 |
|
8 |
def make_step_id(step_number: int):
|
src/components/typed_dicts.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
2 |
|
3 |
-
from workflows.structs import TossupWorkflow, Workflow
|
4 |
|
5 |
|
6 |
# TypedDicts for workflows/structs.py
|
|
|
1 |
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
2 |
|
3 |
+
from shared.workflows.structs import TossupWorkflow, Workflow
|
4 |
|
5 |
|
6 |
# TypedDicts for workflows/structs.py
|
src/populate.py
CHANGED
@@ -49,7 +49,7 @@ def get_tossups_leaderboard_df(repo_dir: str, eval_split: str) -> pd.DataFrame:
|
|
49 |
row["Win Rate w/ Human (Aggressive)"] = metrics["human_win_rate_strict"]
|
50 |
eval_results.append(row)
|
51 |
except Exception as e:
|
52 |
-
logger.error(f"Error processing model result: {e}")
|
53 |
continue
|
54 |
|
55 |
return pd.DataFrame(eval_results)
|
@@ -72,7 +72,7 @@ def get_bonuses_leaderboard_df(repo_dir: str, eval_split: str) -> pd.DataFrame:
|
|
72 |
}
|
73 |
eval_results.append(row)
|
74 |
except Exception as e:
|
75 |
-
logger.error(f"Error processing model result: {e}")
|
76 |
continue
|
77 |
|
78 |
return pd.DataFrame(eval_results)
|
@@ -96,9 +96,7 @@ def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
|
|
96 |
all_evals.append(data)
|
97 |
elif ".md" not in entry:
|
98 |
# this is a folder
|
99 |
-
sub_entries = [
|
100 |
-
e for e in os.listdir(f"{save_path}/{entry}") if os.path.isfile(e) and not e.startswith(".")
|
101 |
-
]
|
102 |
for sub_entry in sub_entries:
|
103 |
file_path = os.path.join(save_path, entry, sub_entry)
|
104 |
with open(file_path) as fp:
|
|
|
49 |
row["Win Rate w/ Human (Aggressive)"] = metrics["human_win_rate_strict"]
|
50 |
eval_results.append(row)
|
51 |
except Exception as e:
|
52 |
+
logger.error(f"Error processing model result '{username}/{model_name}': {e}")
|
53 |
continue
|
54 |
|
55 |
return pd.DataFrame(eval_results)
|
|
|
72 |
}
|
73 |
eval_results.append(row)
|
74 |
except Exception as e:
|
75 |
+
logger.error(f"Error processing model result '{username}/{model_name}': {e}")
|
76 |
continue
|
77 |
|
78 |
return pd.DataFrame(eval_results)
|
|
|
96 |
all_evals.append(data)
|
97 |
elif ".md" not in entry:
|
98 |
# this is a folder
|
99 |
+
sub_entries = [e for e in os.listdir(f"{save_path}/{entry}") if os.path.isfile(e) and not e.startswith(".")]
|
|
|
|
|
100 |
for sub_entry in sub_entries:
|
101 |
file_path = os.path.join(save_path, entry, sub_entry)
|
102 |
with open(file_path) as fp:
|
src/submission/_submit.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from datetime import datetime, timezone
|
4 |
+
|
5 |
+
from src.display.formatting import styled_error, styled_message, styled_warning
|
6 |
+
from src.envs import API, EVAL_REQUESTS_PATH, TOKEN, QUEUE_REPO
|
7 |
+
from src.submission.check_validity import (
|
8 |
+
already_submitted_models,
|
9 |
+
check_model_card,
|
10 |
+
get_model_size,
|
11 |
+
is_model_on_hub,
|
12 |
+
)
|
13 |
+
|
14 |
+
REQUESTED_MODELS = None
|
15 |
+
USERS_TO_SUBMISSION_DATES = None
|
16 |
+
|
17 |
+
def add_new_eval(
|
18 |
+
model: str,
|
19 |
+
base_model: str,
|
20 |
+
revision: str,
|
21 |
+
precision: str,
|
22 |
+
weight_type: str,
|
23 |
+
model_type: str,
|
24 |
+
):
|
25 |
+
global REQUESTED_MODELS
|
26 |
+
global USERS_TO_SUBMISSION_DATES
|
27 |
+
if not REQUESTED_MODELS:
|
28 |
+
REQUESTED_MODELS, USERS_TO_SUBMISSION_DATES = already_submitted_models(EVAL_REQUESTS_PATH)
|
29 |
+
|
30 |
+
user_name = ""
|
31 |
+
model_path = model
|
32 |
+
if "/" in model:
|
33 |
+
user_name = model.split("/")[0]
|
34 |
+
model_path = model.split("/")[1]
|
35 |
+
|
36 |
+
precision = precision.split(" ")[0]
|
37 |
+
current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
38 |
+
|
39 |
+
if model_type is None or model_type == "":
|
40 |
+
return styled_error("Please select a model type.")
|
41 |
+
|
42 |
+
# Does the model actually exist?
|
43 |
+
if revision == "":
|
44 |
+
revision = "main"
|
45 |
+
|
46 |
+
# Is the model on the hub?
|
47 |
+
if weight_type in ["Delta", "Adapter"]:
|
48 |
+
base_model_on_hub, error, _ = is_model_on_hub(model_name=base_model, revision=revision, token=TOKEN, test_tokenizer=True)
|
49 |
+
if not base_model_on_hub:
|
50 |
+
return styled_error(f'Base model "{base_model}" {error}')
|
51 |
+
|
52 |
+
if not weight_type == "Adapter":
|
53 |
+
model_on_hub, error, _ = is_model_on_hub(model_name=model, revision=revision, token=TOKEN, test_tokenizer=True)
|
54 |
+
if not model_on_hub:
|
55 |
+
return styled_error(f'Model "{model}" {error}')
|
56 |
+
|
57 |
+
# Is the model info correctly filled?
|
58 |
+
try:
|
59 |
+
model_info = API.model_info(repo_id=model, revision=revision)
|
60 |
+
except Exception:
|
61 |
+
return styled_error("Could not get your model information. Please fill it up properly.")
|
62 |
+
|
63 |
+
model_size = get_model_size(model_info=model_info, precision=precision)
|
64 |
+
|
65 |
+
# Were the model card and license filled?
|
66 |
+
try:
|
67 |
+
license = model_info.cardData["license"]
|
68 |
+
except Exception:
|
69 |
+
return styled_error("Please select a license for your model")
|
70 |
+
|
71 |
+
modelcard_OK, error_msg = check_model_card(model)
|
72 |
+
if not modelcard_OK:
|
73 |
+
return styled_error(error_msg)
|
74 |
+
|
75 |
+
# Seems good, creating the eval
|
76 |
+
print("Adding new eval")
|
77 |
+
|
78 |
+
eval_entry = {
|
79 |
+
"model": model,
|
80 |
+
"base_model": base_model,
|
81 |
+
"revision": revision,
|
82 |
+
"precision": precision,
|
83 |
+
"weight_type": weight_type,
|
84 |
+
"status": "PENDING",
|
85 |
+
"submitted_time": current_time,
|
86 |
+
"model_type": model_type,
|
87 |
+
"likes": model_info.likes,
|
88 |
+
"params": model_size,
|
89 |
+
"license": license,
|
90 |
+
"private": False,
|
91 |
+
}
|
92 |
+
|
93 |
+
# Check for duplicate submission
|
94 |
+
if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
|
95 |
+
return styled_warning("This model has been already submitted.")
|
96 |
+
|
97 |
+
print("Creating eval file")
|
98 |
+
OUT_DIR = f"{EVAL_REQUESTS_PATH}/{user_name}"
|
99 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
100 |
+
out_path = f"{OUT_DIR}/{model_path}_eval_request_False_{precision}_{weight_type}.json"
|
101 |
+
|
102 |
+
with open(out_path, "w") as f:
|
103 |
+
f.write(json.dumps(eval_entry))
|
104 |
+
|
105 |
+
print("Uploading eval file")
|
106 |
+
API.upload_file(
|
107 |
+
path_or_fileobj=out_path,
|
108 |
+
path_in_repo=out_path.split("eval-queue/")[1],
|
109 |
+
repo_id=QUEUE_REPO,
|
110 |
+
repo_type="dataset",
|
111 |
+
commit_message=f"Add {model} to eval queue",
|
112 |
+
)
|
113 |
+
|
114 |
+
# Remove the local file
|
115 |
+
os.remove(out_path)
|
116 |
+
|
117 |
+
return styled_message(
|
118 |
+
"Your request has been submitted to the evaluation queue!\nPlease wait for up to an hour for the model to show in the PENDING list."
|
119 |
+
)
|
src/submission/check_validity.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from collections import defaultdict
|
5 |
+
from datetime import datetime, timedelta, timezone
|
6 |
+
|
7 |
+
import huggingface_hub
|
8 |
+
from huggingface_hub import ModelCard
|
9 |
+
from huggingface_hub.hf_api import ModelInfo
|
10 |
+
from transformers import AutoConfig
|
11 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
12 |
+
|
13 |
+
def check_model_card(repo_id: str) -> tuple[bool, str]:
|
14 |
+
"""Checks if the model card and license exist and have been filled"""
|
15 |
+
try:
|
16 |
+
card = ModelCard.load(repo_id)
|
17 |
+
except huggingface_hub.utils.EntryNotFoundError:
|
18 |
+
return False, "Please add a model card to your model to explain how you trained/fine-tuned it."
|
19 |
+
|
20 |
+
# Enforce license metadata
|
21 |
+
if card.data.license is None:
|
22 |
+
if not ("license_name" in card.data and "license_link" in card.data):
|
23 |
+
return False, (
|
24 |
+
"License not found. Please add a license to your model card using the `license` metadata or a"
|
25 |
+
" `license_name`/`license_link` pair."
|
26 |
+
)
|
27 |
+
|
28 |
+
# Enforce card content
|
29 |
+
if len(card.text) < 200:
|
30 |
+
return False, "Please add a description to your model card, it is too short."
|
31 |
+
|
32 |
+
return True, ""
|
33 |
+
|
34 |
+
def is_model_on_hub(model_name: str, revision: str, token: str = None, trust_remote_code=False, test_tokenizer=False) -> tuple[bool, str]:
|
35 |
+
"""Checks if the model model_name is on the hub, and whether it (and its tokenizer) can be loaded with AutoClasses."""
|
36 |
+
try:
|
37 |
+
config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
|
38 |
+
if test_tokenizer:
|
39 |
+
try:
|
40 |
+
tk = AutoTokenizer.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
|
41 |
+
except ValueError as e:
|
42 |
+
return (
|
43 |
+
False,
|
44 |
+
f"uses a tokenizer which is not in a transformers release: {e}",
|
45 |
+
None
|
46 |
+
)
|
47 |
+
except Exception as e:
|
48 |
+
return (False, "'s tokenizer cannot be loaded. Is your tokenizer class in a stable transformers release, and correctly configured?", None)
|
49 |
+
return True, None, config
|
50 |
+
|
51 |
+
except ValueError:
|
52 |
+
return (
|
53 |
+
False,
|
54 |
+
"needs to be launched with `trust_remote_code=True`. For safety reason, we do not allow these models to be automatically submitted to the leaderboard.",
|
55 |
+
None
|
56 |
+
)
|
57 |
+
|
58 |
+
except Exception as e:
|
59 |
+
return False, "was not found on hub!", None
|
60 |
+
|
61 |
+
|
62 |
+
def get_model_size(model_info: ModelInfo, precision: str):
|
63 |
+
"""Gets the model size from the configuration, or the model name if the configuration does not contain the information."""
|
64 |
+
try:
|
65 |
+
model_size = round(model_info.safetensors["total"] / 1e9, 3)
|
66 |
+
except (AttributeError, TypeError):
|
67 |
+
return 0 # Unknown model sizes are indicated as 0, see NUMERIC_INTERVALS in app.py
|
68 |
+
|
69 |
+
size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.modelId.lower()) else 1
|
70 |
+
model_size = size_factor * model_size
|
71 |
+
return model_size
|
72 |
+
|
73 |
+
def get_model_arch(model_info: ModelInfo):
|
74 |
+
"""Gets the model architecture from the configuration"""
|
75 |
+
return model_info.config.get("architectures", "Unknown")
|
76 |
+
|
77 |
+
def already_submitted_models(requested_models_dir: str) -> set[str]:
|
78 |
+
"""Gather a list of already submitted models to avoid duplicates"""
|
79 |
+
depth = 1
|
80 |
+
file_names = []
|
81 |
+
users_to_submission_dates = defaultdict(list)
|
82 |
+
|
83 |
+
for root, _, files in os.walk(requested_models_dir):
|
84 |
+
current_depth = root.count(os.sep) - requested_models_dir.count(os.sep)
|
85 |
+
if current_depth == depth:
|
86 |
+
for file in files:
|
87 |
+
if not file.endswith(".json"):
|
88 |
+
continue
|
89 |
+
with open(os.path.join(root, file), "r") as f:
|
90 |
+
info = json.load(f)
|
91 |
+
file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}")
|
92 |
+
|
93 |
+
# Select organisation
|
94 |
+
if info["model"].count("/") == 0 or "submitted_time" not in info:
|
95 |
+
continue
|
96 |
+
organisation, _ = info["model"].split("/")
|
97 |
+
users_to_submission_dates[organisation].append(info["submitted_time"])
|
98 |
+
|
99 |
+
return set(file_names), users_to_submission_dates
|
src/submission/structs.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional
|
|
3 |
|
4 |
from pydantic import BaseModel, Field
|
5 |
|
6 |
-
from workflows.structs import TossupWorkflow, Workflow
|
7 |
|
8 |
CompetitionType = Literal["tossup", "bonus"]
|
9 |
SubmissionType = Literal["python_file", "simple_workflow", "complex_workflow"]
|
|
|
3 |
|
4 |
from pydantic import BaseModel, Field
|
5 |
|
6 |
+
from shared.workflows.structs import TossupWorkflow, Workflow
|
7 |
|
8 |
CompetitionType = Literal["tossup", "bonus"]
|
9 |
SubmissionType = Literal["python_file", "simple_workflow", "complex_workflow"]
|
src/submission/submit.py
CHANGED
@@ -13,8 +13,8 @@ from loguru import logger
|
|
13 |
from app_configs import DAILY_SUBMISSION_LIMIT_PER_USER
|
14 |
from display.formatting import styled_error, styled_message
|
15 |
from envs import API, EVAL_REQUESTS_PATH, EXAMPLES_PATH, OWNER, QUEUE_REPO
|
|
|
16 |
from submission.structs import CompetitionType, Submission, SubmissionStatus
|
17 |
-
from workflows.structs import TossupWorkflow, Workflow
|
18 |
|
19 |
|
20 |
def get_user_submissions(username: str, competition_type: str, pattern: str = None) -> list[Submission]:
|
@@ -238,7 +238,7 @@ def load_submission(model_name: str, competition_type: CompetitionType, profile:
|
|
238 |
|
239 |
if __name__ == "__main__":
|
240 |
# Example usage
|
241 |
-
from workflows.factory import create_quizbowl_simple_step_initial_setup
|
242 |
|
243 |
# Create workflow
|
244 |
model_step = create_quizbowl_simple_step_initial_setup()
|
|
|
13 |
from app_configs import DAILY_SUBMISSION_LIMIT_PER_USER
|
14 |
from display.formatting import styled_error, styled_message
|
15 |
from envs import API, EVAL_REQUESTS_PATH, EXAMPLES_PATH, OWNER, QUEUE_REPO
|
16 |
+
from shared.workflows.structs import TossupWorkflow, Workflow
|
17 |
from submission.structs import CompetitionType, Submission, SubmissionStatus
|
|
|
18 |
|
19 |
|
20 |
def get_user_submissions(username: str, competition_type: str, pattern: str = None) -> list[Submission]:
|
|
|
238 |
|
239 |
if __name__ == "__main__":
|
240 |
# Example usage
|
241 |
+
from shared.workflows.factory import create_quizbowl_simple_step_initial_setup
|
242 |
|
243 |
# Create workflow
|
244 |
model_step = create_quizbowl_simple_step_initial_setup()
|
src/workflows/README.md
DELETED
@@ -1,129 +0,0 @@
|
|
1 |
-
# Workflows Subpackage
|
2 |
-
|
3 |
-
This subpackage provides a framework for defining, validating, and executing workflows composed of interconnected model steps with dependency management.
|
4 |
-
|
5 |
-
## Overview
|
6 |
-
|
7 |
-
The workflows subpackage enables the creation and execution of workflows where multiple model steps can be combined, with outputs from earlier steps feeding into inputs of later steps. The package handles dependency resolution, execution order, and error handling.
|
8 |
-
|
9 |
-
## Components
|
10 |
-
|
11 |
-
### `structs.py`
|
12 |
-
|
13 |
-
Contains the core data structures used throughout the workflow system:
|
14 |
-
|
15 |
-
- `InputField`: Represents an input field with name, description, and variable reference
|
16 |
-
- `OutputField`: Represents an output field with name, type, and description
|
17 |
-
- `ModelStep`: Represents a single step in a workflow with input fields, output fields, and model details
|
18 |
-
- `Workflow`: A collection of ModelSteps with their identifiers
|
19 |
-
- `TossupWorkflow`: Specialized workflow for quizbowl tossup questions with buzzing capability
|
20 |
-
|
21 |
-
### `configs.py`
|
22 |
-
|
23 |
-
Provides configuration settings and constants:
|
24 |
-
|
25 |
-
- `AVAILABLE_MODELS`: Supported model configurations from various providers
|
26 |
-
- `TYPE_MAP`: Mapping of supported field types to Python types
|
27 |
-
- `FUNCTION_MAP`: Built-in transformation functions for input/output processing
|
28 |
-
|
29 |
-
### `utils.py`
|
30 |
-
|
31 |
-
Provides utility functions for workflow operations:
|
32 |
-
|
33 |
-
- `create_dependency_graph`: Builds a dependency graph representing the execution order constraints
|
34 |
-
- `topological_sort`: Sorts steps in execution order based on their dependencies
|
35 |
-
- `detect_cycles`: Identifies cyclic dependencies in workflow definitions
|
36 |
-
|
37 |
-
### `executors.py`
|
38 |
-
|
39 |
-
Handles the execution of workflows:
|
40 |
-
|
41 |
-
- `execute_model_step`: Executes a single model step with input processing and output collection
|
42 |
-
- `execute_simple_workflow`: Handles single-step workflows
|
43 |
-
- `execute_multi_step_workflow`: Manages multi-step workflows with dependency resolution
|
44 |
-
- `execute_workflow`: Main entry point that routes to appropriate executor based on workflow complexity
|
45 |
-
|
46 |
-
### `validators.py`
|
47 |
-
|
48 |
-
Provides workflow validation functionality:
|
49 |
-
|
50 |
-
- `ValidationErrorType`: Enumeration of possible validation error types
|
51 |
-
- `WorkflowValidationError`: Base class for validation errors
|
52 |
-
- Validation functions for steps, DAGs, variables, and types
|
53 |
-
|
54 |
-
### `errors.py`
|
55 |
-
|
56 |
-
Defines custom exceptions for workflow-related errors:
|
57 |
-
|
58 |
-
- `WorkflowError`: Base class for workflow errors
|
59 |
-
- `CyclicDependencyError`: Raised when detecting cycles in the workflow graph
|
60 |
-
- `UnknownVariableError`: Raised when a step requires a variable that's not provided or produced
|
61 |
-
|
62 |
-
## Usage Example
|
63 |
-
|
64 |
-
```python
|
65 |
-
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
66 |
-
|
67 |
-
# Define a workflow with two steps
|
68 |
-
step1 = ModelStep(
|
69 |
-
id="step1",
|
70 |
-
model="gpt-4o-mini",
|
71 |
-
provider="OpenAI",
|
72 |
-
call_type="llm",
|
73 |
-
system_prompt="Step1 processing",
|
74 |
-
input_fields=[InputField(name="value", description="Input value", variable="input.value")],
|
75 |
-
output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
|
76 |
-
)
|
77 |
-
|
78 |
-
step2 = ModelStep(
|
79 |
-
id="step2",
|
80 |
-
model="gpt-4o-mini",
|
81 |
-
provider="OpenAI",
|
82 |
-
call_type="llm",
|
83 |
-
system_prompt="Step2 processing",
|
84 |
-
input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
|
85 |
-
output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
|
86 |
-
)
|
87 |
-
|
88 |
-
workflow = Workflow(
|
89 |
-
steps={"step1": step1, "step2": step2},
|
90 |
-
inputs=["input.value"],
|
91 |
-
outputs={"final": "step2.final"}
|
92 |
-
)
|
93 |
-
|
94 |
-
# Execute the workflow
|
95 |
-
from workflows.executors import execute_workflow
|
96 |
-
|
97 |
-
result = execute_workflow(
|
98 |
-
workflow=workflow,
|
99 |
-
input_values={"input.value": "Hello, World!"},
|
100 |
-
return_full_content=True,
|
101 |
-
logprob_step="step2"
|
102 |
-
)
|
103 |
-
|
104 |
-
# Access results
|
105 |
-
final_output = result["final_outputs"]["final"]
|
106 |
-
intermediate_results = result["intermediate_outputs"]
|
107 |
-
step_contents = result["step_contents"]
|
108 |
-
logprob = result["logprob"]
|
109 |
-
```
|
110 |
-
|
111 |
-
## Error Handling
|
112 |
-
|
113 |
-
The workflows system provides robust error handling:
|
114 |
-
|
115 |
-
- Detects cyclic dependencies in workflow definitions
|
116 |
-
- Validates input/output variable references
|
117 |
-
- Ensures all required inputs are provided
|
118 |
-
- Supports custom validation rules through the validation system
|
119 |
-
- Provides detailed error messages for debugging
|
120 |
-
|
121 |
-
## Extending the Workflows System
|
122 |
-
|
123 |
-
To extend the workflows system:
|
124 |
-
|
125 |
-
1. Add new model step types by extending the `ModelStep` class
|
126 |
-
2. Create custom field types by extending validation in the execution logic
|
127 |
-
3. Implement additional error types in `errors.py` for specialized error handling
|
128 |
-
4. Add new transformation functions to `FUNCTION_MAP` in `configs.py`
|
129 |
-
5. Create specialized workflow types by extending the `Workflow` class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/configs.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Configuration settings for the workflows package.
|
3 |
-
|
4 |
-
This module contains configuration settings and constants used across the workflows package,
|
5 |
-
including model configurations, workflow settings, and other package-wide constants.
|
6 |
-
"""
|
7 |
-
|
8 |
-
AVAILABLE_MODELS = {
|
9 |
-
"OpenAI/gpt-4o": {
|
10 |
-
"model": "gpt-4o-2024-11-20",
|
11 |
-
"logprobs": True,
|
12 |
-
},
|
13 |
-
"OpenAI/gpt-4o-mini": {
|
14 |
-
"model": "gpt-4o-mini-2024-07-18",
|
15 |
-
"logprobs": True,
|
16 |
-
},
|
17 |
-
"OpenAI/gpt-3.5-turbo": {
|
18 |
-
"model": "gpt-3.5-turbo-0125",
|
19 |
-
},
|
20 |
-
"Anthropic/claude-3-7-sonnet": {
|
21 |
-
"model": "claude-3-7-sonnet-20250219",
|
22 |
-
},
|
23 |
-
"Anthropic/claude-3-5-sonnet": {
|
24 |
-
"model": "claude-3-5-sonnet-20241022",
|
25 |
-
},
|
26 |
-
"Anthropic/claude-3-5-haiku": {
|
27 |
-
"model": "claude-3-5-haiku-20241022",
|
28 |
-
},
|
29 |
-
"Cohere/command-r": {
|
30 |
-
"model": "command-r-08-2024",
|
31 |
-
"logprobs": True,
|
32 |
-
},
|
33 |
-
"Cohere/command-r-plus": {
|
34 |
-
"model": "command-r-plus-08-2024",
|
35 |
-
"logprobs": True,
|
36 |
-
},
|
37 |
-
"Cohere/command-r7b": {
|
38 |
-
"model": "command-r7b-12-2024",
|
39 |
-
"logprobs": False,
|
40 |
-
},
|
41 |
-
}
|
42 |
-
|
43 |
-
# Function mapping for input/output transformations
|
44 |
-
TYPE_MAP = {
|
45 |
-
"str": str,
|
46 |
-
"int": int,
|
47 |
-
"float": float,
|
48 |
-
"bool": bool,
|
49 |
-
}
|
50 |
-
|
51 |
-
FUNCTION_MAP = {
|
52 |
-
"upper": str.upper,
|
53 |
-
"lower": str.lower,
|
54 |
-
"len": len,
|
55 |
-
"split": str.split,
|
56 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/errors.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Custom exceptions for workflow validation and execution errors.
|
3 |
-
|
4 |
-
This module defines the exception hierarchy for the workflows package, enabling
|
5 |
-
specific error types to be raised and caught during workflow validation and execution.
|
6 |
-
Each exception provides detailed error messages to help diagnose and fix issues in
|
7 |
-
workflow definitions or execution.
|
8 |
-
|
9 |
-
Exception hierarchy:
|
10 |
-
- WorkflowError (base class)
|
11 |
-
- UnknownVariableError (missing variable reference)
|
12 |
-
- CyclicDependencyError (circular dependencies)
|
13 |
-
- FunctionNotFoundError (missing function reference)
|
14 |
-
"""
|
15 |
-
|
16 |
-
|
17 |
-
# Define custom exceptions for workflow errors
|
18 |
-
class WorkflowError(Exception):
|
19 |
-
"""
|
20 |
-
Base exception class for all workflow-related errors.
|
21 |
-
|
22 |
-
This is the parent class for all workflow-specific exceptions and can be used
|
23 |
-
to catch any error from the workflows package.
|
24 |
-
"""
|
25 |
-
|
26 |
-
pass
|
27 |
-
|
28 |
-
|
29 |
-
class UnknownVariableError(WorkflowError):
|
30 |
-
"""
|
31 |
-
Raised when a workflow step references a variable that doesn't exist.
|
32 |
-
|
33 |
-
This typically occurs when a step's input field references a variable that is neither
|
34 |
-
provided as an external input nor produced as an output by any previous step.
|
35 |
-
"""
|
36 |
-
|
37 |
-
def __init__(self, var: str):
|
38 |
-
super().__init__(f"Unknown variable referenced: {var}")
|
39 |
-
|
40 |
-
|
41 |
-
class CyclicDependencyError(WorkflowError):
|
42 |
-
"""
|
43 |
-
Raised when a cyclic dependency is detected in a workflow.
|
44 |
-
|
45 |
-
A cyclic dependency occurs when there is a circular reference in the workflow graph,
|
46 |
-
such as step A depending on step B, which depends on step A. Such workflows cannot
|
47 |
-
be executed because there's no valid order to process the steps.
|
48 |
-
"""
|
49 |
-
|
50 |
-
def __init__(self):
|
51 |
-
super().__init__("Cyclic dependency detected in workflow")
|
52 |
-
|
53 |
-
|
54 |
-
class FunctionNotFoundError(WorkflowError):
|
55 |
-
"""
|
56 |
-
Raised when a referenced function cannot be found during workflow execution.
|
57 |
-
|
58 |
-
This typically occurs when a step references a function that doesn't exist in
|
59 |
-
the available function registry or namespace.
|
60 |
-
"""
|
61 |
-
|
62 |
-
def __init__(self, func_name: str):
|
63 |
-
super().__init__(f"Function not found: {func_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/executors.py
DELETED
@@ -1,673 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Core workflow execution functionality.
|
3 |
-
|
4 |
-
This module handles the execution of defined workflows, including input processing,
|
5 |
-
dependency-based execution order, model calling, and output collection. It integrates
|
6 |
-
with the litellm library to handle model interactions.
|
7 |
-
|
8 |
-
Key components:
|
9 |
-
- Utility functions for input/output transformation
|
10 |
-
- Input processing and validation
|
11 |
-
- Model step execution with support for log probabilities
|
12 |
-
- Complete workflow execution with dependency resolution
|
13 |
-
- Support for both simple (single-step) and multi-step workflows
|
14 |
-
- Structured output collection with intermediate results
|
15 |
-
|
16 |
-
The module orchestrates the execution of steps in the correct order based on their
|
17 |
-
dependencies and manages the flow of data between steps. It supports:
|
18 |
-
- Full content tracking for debugging
|
19 |
-
- Log probability calculation for specific steps
|
20 |
-
- Flexible input/output transformations
|
21 |
-
- Error handling and validation
|
22 |
-
"""
|
23 |
-
|
24 |
-
from typing import Any, TypedDict
|
25 |
-
|
26 |
-
import pydantic
|
27 |
-
|
28 |
-
from .configs import FUNCTION_MAP, TYPE_MAP
|
29 |
-
from .errors import WorkflowError
|
30 |
-
from .llms import completion
|
31 |
-
from .structs import InputField, ModelStep, OutputField, Workflow
|
32 |
-
from .utils import create_dependency_graph, topological_sort
|
33 |
-
|
34 |
-
|
35 |
-
def get_type(type_str: str) -> type:
|
36 |
-
"""
|
37 |
-
Converts a type string to its corresponding Python type.
|
38 |
-
|
39 |
-
This function maps type strings to their actual Python type objects. It first checks
|
40 |
-
the TYPE_MAP dictionary for predefined mappings, and if not found, falls back to
|
41 |
-
evaluating the type string directly.
|
42 |
-
|
43 |
-
Args:
|
44 |
-
type_str (str): A string representation of a type (e.g., "str", "int", "list[str]")
|
45 |
-
|
46 |
-
Returns:
|
47 |
-
type: The corresponding Python type object
|
48 |
-
|
49 |
-
Note:
|
50 |
-
Uses eval() for non-predefined types, which has security implications if used
|
51 |
-
with untrusted input. This is intended for internal use with validated type strings.
|
52 |
-
"""
|
53 |
-
return TYPE_MAP.get(type_str, eval(type_str))
|
54 |
-
|
55 |
-
|
56 |
-
def create_processed_inputs(model_step: ModelStep, available_vars: dict[str, Any]) -> dict[str, Any]:
|
57 |
-
"""
|
58 |
-
Creates processed inputs for a model step.
|
59 |
-
|
60 |
-
This function extracts and processes the required inputs for a model step based on
|
61 |
-
its input field definitions. It retrieves values from the available variables dictionary
|
62 |
-
and applies any specified transformations.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
model_step (ModelStep): The model step for which to create processed inputs.
|
66 |
-
available_vars (dict[str, Any]): Dictionary of variables available for use as inputs.
|
67 |
-
Keys are variable names, values are the variable values.
|
68 |
-
|
69 |
-
Returns:
|
70 |
-
dict[str, Any]: A dictionary of processed inputs ready for use by the model step.
|
71 |
-
Keys are input field names, values are the processed input values.
|
72 |
-
|
73 |
-
Raises:
|
74 |
-
WorkflowError: If a required variable is not found in available_vars,
|
75 |
-
or if a specified transformation function is not available.
|
76 |
-
|
77 |
-
Example:
|
78 |
-
>>> available_vars = {"step1.output": "Hello World"}
|
79 |
-
>>> create_processed_inputs(model_step, available_vars)
|
80 |
-
{"input_field_name": "HELLO WORLD"} # If upper transformation was specified
|
81 |
-
"""
|
82 |
-
processed_inputs: dict[str, Any] = {}
|
83 |
-
for input_field in model_step.input_fields:
|
84 |
-
var = input_field.variable
|
85 |
-
value = available_vars[var]
|
86 |
-
if input_field.func is not None:
|
87 |
-
func = FUNCTION_MAP.get(input_field.func)
|
88 |
-
func = func or eval(input_field.func)
|
89 |
-
value = func(value)
|
90 |
-
processed_inputs[input_field.name] = value
|
91 |
-
return processed_inputs
|
92 |
-
|
93 |
-
|
94 |
-
class ModelStepResult(TypedDict):
|
95 |
-
"""
|
96 |
-
Result of executing a model step.
|
97 |
-
|
98 |
-
This TypedDict contains the outputs and metadata from executing a single model step,
|
99 |
-
including the processed output values, the full response content, and log probability
|
100 |
-
information when requested.
|
101 |
-
|
102 |
-
Attributes:
|
103 |
-
outputs (dict[str, Any]): A dictionary of processed outputs from the model step,
|
104 |
-
with keys matching the output field names.
|
105 |
-
content (str | None): The full content of the model's response, only populated
|
106 |
-
if return_full_content is True.
|
107 |
-
logprob (float | None): The log probability of the model step output, only populated
|
108 |
-
if logprobs is True.
|
109 |
-
"""
|
110 |
-
|
111 |
-
# A dictionary of processed outputs from the model step,
|
112 |
-
# with keys matching the output field names.
|
113 |
-
outputs: dict[str, Any]
|
114 |
-
|
115 |
-
# The full content of the model step.
|
116 |
-
content: str | None
|
117 |
-
|
118 |
-
# The log probability of the model step output if requested.
|
119 |
-
logprob: float | None
|
120 |
-
|
121 |
-
|
122 |
-
class WorkflowOutput(TypedDict):
|
123 |
-
"""
|
124 |
-
Result of executing a complete workflow.
|
125 |
-
|
126 |
-
This TypedDict contains the outputs and metadata from executing a workflow,
|
127 |
-
including final outputs, intermediate values, step contents, and log probabilities.
|
128 |
-
|
129 |
-
Attributes:
|
130 |
-
final_outputs (dict[str, Any]): The final output values produced by the workflow,
|
131 |
-
with keys matching the names defined in workflow.outputs.
|
132 |
-
intermediate_outputs (dict[str, Any]): All computed values during workflow execution,
|
133 |
-
including both external inputs and outputs from all steps.
|
134 |
-
step_contents (dict[str, Any]): Full response content for each step, keyed by step ID.
|
135 |
-
Only populated if return_full_content is True.
|
136 |
-
logprob (float | None): The log probability of the specified step's output.
|
137 |
-
Only populated if logprob_step is specified.
|
138 |
-
"""
|
139 |
-
|
140 |
-
# A dictionary of the workflow's outputs, with keys matching the variables defined in workflow.outputs.
|
141 |
-
final_outputs: dict[str, Any]
|
142 |
-
|
143 |
-
# A dictionary of all computed values during workflow execution, including intermediate results.
|
144 |
-
intermediate_outputs: dict[str, Any]
|
145 |
-
|
146 |
-
# A dictionary of step contents, only populated if return_full_content is True.
|
147 |
-
step_contents: dict[str, Any]
|
148 |
-
|
149 |
-
# The log probability of the workflow output if requested.
|
150 |
-
logprob: float | None
|
151 |
-
|
152 |
-
|
153 |
-
# %%
|
154 |
-
def execute_model_step(
|
155 |
-
model_step: ModelStep,
|
156 |
-
available_vars: dict[str, Any],
|
157 |
-
return_full_content: bool = False,
|
158 |
-
logprobs: bool = False,
|
159 |
-
) -> ModelStepResult:
|
160 |
-
"""
|
161 |
-
Executes a model step using the provided available variables.
|
162 |
-
|
163 |
-
This function handles the complete execution of a model step, including:
|
164 |
-
1. Processing inputs using variable references and transformations
|
165 |
-
2. Constructing the appropriate prompt for the model
|
166 |
-
3. Calling the model via litellm with structured output
|
167 |
-
4. Processing and validating the model's response
|
168 |
-
5. Applying any output transformations
|
169 |
-
|
170 |
-
The function supports different providers and model types through the litellm
|
171 |
-
integration, allowing for a consistent interface regardless of the underlying model.
|
172 |
-
|
173 |
-
Args:
|
174 |
-
model_step (ModelStep): The model step to execute, containing model details,
|
175 |
-
input/output specifications, and system prompt.
|
176 |
-
available_vars (dict[str, Any]): A dictionary of all variables available to this step,
|
177 |
-
including outputs from previous steps and external inputs.
|
178 |
-
return_full_content (bool, optional): If True, includes the full model response content
|
179 |
-
in the result. Defaults to False.
|
180 |
-
logprobs (bool, optional): If True, calculates and returns log probability information
|
181 |
-
for the model response. Defaults to False.
|
182 |
-
|
183 |
-
Returns:
|
184 |
-
ModelStepResult: A TypedDict containing processed outputs, optional full content,
|
185 |
-
and optional log probability information.
|
186 |
-
|
187 |
-
Raises:
|
188 |
-
WorkflowError: If there's an error in input processing, model execution,
|
189 |
-
or output validation.
|
190 |
-
|
191 |
-
Example:
|
192 |
-
>>> step = ModelStep(
|
193 |
-
... id="summarize",
|
194 |
-
... model="gpt-3.5-turbo",
|
195 |
-
... provider="openai",
|
196 |
-
... call_type="llm",
|
197 |
-
... system_prompt="Summarize the text",
|
198 |
-
... input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
|
199 |
-
... output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
|
200 |
-
... )
|
201 |
-
>>> result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
202 |
-
>>> summary = result["outputs"]["summary"]
|
203 |
-
"""
|
204 |
-
# Ensure inputs are processed using the specified functions in input_fields.
|
205 |
-
processed_inputs = create_processed_inputs(model_step, available_vars)
|
206 |
-
|
207 |
-
# Construct the input prompt for the model
|
208 |
-
input_str = "\n".join(f"{k}: {v}" for k, v in processed_inputs.items())
|
209 |
-
step_result = f"Inputs: \n{input_str}"
|
210 |
-
|
211 |
-
# Define the expected output fields and their types
|
212 |
-
fields = {
|
213 |
-
field.name: (get_type(field.type), pydantic.Field(..., description=field.description))
|
214 |
-
for field in model_step.output_fields
|
215 |
-
}
|
216 |
-
ModelResponse = pydantic.create_model("ModelResponse", **fields)
|
217 |
-
|
218 |
-
# Execute the model step using litellm
|
219 |
-
api_response = completion(
|
220 |
-
model=f"{model_step.provider}/{model_step.model}",
|
221 |
-
system=model_step.system_prompt,
|
222 |
-
prompt=step_result,
|
223 |
-
response_format=ModelResponse,
|
224 |
-
temperature=model_step.temperature,
|
225 |
-
logprobs=logprobs,
|
226 |
-
)
|
227 |
-
|
228 |
-
# Map the parsed response to the output fields
|
229 |
-
outputs = {field.name: api_response["output"][field.name] for field in model_step.output_fields}
|
230 |
-
result = ModelStepResult(outputs=outputs, content=None, logprob=None)
|
231 |
-
if return_full_content:
|
232 |
-
result["content"] = api_response["content"]
|
233 |
-
if logprobs:
|
234 |
-
result["logprob"] = api_response.get("logprob")
|
235 |
-
return result
|
236 |
-
|
237 |
-
|
238 |
-
def execute_multi_step_workflow(
|
239 |
-
workflow: Workflow,
|
240 |
-
input_values: dict[str, Any],
|
241 |
-
return_full_content: bool = False,
|
242 |
-
logprob_step: str | None = None,
|
243 |
-
) -> WorkflowOutput:
|
244 |
-
"""
|
245 |
-
Execute the given workflow as a computational graph.
|
246 |
-
|
247 |
-
This function orchestrates the complete execution of a workflow by:
|
248 |
-
|
249 |
-
1. Validating and populating initial values using the provided external inputs
|
250 |
-
2. Building a dependency graph between workflow steps
|
251 |
-
3. Determining a valid execution order using topological sorting
|
252 |
-
4. Executing each step in the correct order, with inputs from previous steps
|
253 |
-
5. Collecting and returning the final outputs
|
254 |
-
|
255 |
-
The execution process ensures that all dependencies are satisfied before a step
|
256 |
-
is executed, and that the data flows correctly between steps according to the
|
257 |
-
variable references defined in each step's input fields.
|
258 |
-
|
259 |
-
Args:
|
260 |
-
workflow (Workflow): The workflow to execute, containing steps, their
|
261 |
-
dependencies, and input/output specifications.
|
262 |
-
input_values (dict[str, Any]): External input values to be used by the workflow.
|
263 |
-
Keys should match the required workflow.inputs.
|
264 |
-
return_full_content (bool, optional): If True, returns the full content of each step.
|
265 |
-
Defaults to False.
|
266 |
-
logprob_step (str, optional): The ID of the step to use for log probability calculation.
|
267 |
-
Defaults to None.
|
268 |
-
|
269 |
-
Returns:
|
270 |
-
WorkflowOutput: A dictionary of workflow outputs, including final outputs, intermediate outputs, and step contents.
|
271 |
-
|
272 |
-
Raises:
|
273 |
-
UnknownVariableError: If an input_field references a variable that is not
|
274 |
-
provided externally nor produced by any step.
|
275 |
-
CyclicDependencyError: If the workflow contains a circular dependency that
|
276 |
-
prevents a valid execution order.
|
277 |
-
FunctionNotFoundError: If a transformation function specified in input_fields.func
|
278 |
-
or output_fields.func is not available.
|
279 |
-
WorkflowError: For any other workflow-related errors, such as missing required inputs.
|
280 |
-
|
281 |
-
Example:
|
282 |
-
>>> workflow = Workflow(
|
283 |
-
... steps={
|
284 |
-
... "extract": ModelStep(...), # A step that extracts entities
|
285 |
-
... "analyze": ModelStep(...) # A step that analyzes the entities
|
286 |
-
... },
|
287 |
-
... inputs=["text"],
|
288 |
-
... outputs={"sentiment": "analyze.sentiment", "entities": "extract.entities"}
|
289 |
-
... )
|
290 |
-
>>> final_outputs, computed_values, step_contents = execute_workflow(workflow, {"text": "Apple is launching a new product tomorrow."})
|
291 |
-
>>> print(final_outputs["sentiment"])
|
292 |
-
"positive"
|
293 |
-
>>> print(final_outputs["entities"])
|
294 |
-
["Apple", "product"]
|
295 |
-
"""
|
296 |
-
# Step 1: Pre-populate computed values with external workflow inputs.
|
297 |
-
computed_values: dict[str, Any] = {}
|
298 |
-
for var in workflow.inputs:
|
299 |
-
if var not in input_values:
|
300 |
-
raise WorkflowError(f"Missing required workflow input: {var}")
|
301 |
-
computed_values[var] = input_values[var]
|
302 |
-
|
303 |
-
# Step 2: Build dependency graph among model steps.
|
304 |
-
# For each step, examine its input_fields. If an input is not in the pre-populated external inputs,
|
305 |
-
# then it is expected to be produced by some step. Otherwise, raise an error.
|
306 |
-
dependencies = create_dependency_graph(workflow, input_values)
|
307 |
-
|
308 |
-
# Step 3: Determine the execution order of the steps using topological sort.
|
309 |
-
# Raises an error if a cycle is detected.
|
310 |
-
execution_order = topological_sort(dependencies)
|
311 |
-
|
312 |
-
# Step 4: Execute steps in topological order.
|
313 |
-
step_contents: dict[str, Any] = {}
|
314 |
-
logprob = None
|
315 |
-
for step_id in execution_order:
|
316 |
-
step = workflow.steps[step_id]
|
317 |
-
return_logprobs = logprob_step == step_id
|
318 |
-
# Execute the step
|
319 |
-
result = execute_model_step(
|
320 |
-
step, computed_values, return_full_content=return_full_content, logprobs=return_logprobs
|
321 |
-
)
|
322 |
-
if return_logprobs:
|
323 |
-
logprob = result["logprob"]
|
324 |
-
if return_full_content:
|
325 |
-
step_contents[step_id] = result["content"]
|
326 |
-
outputs = {f"{step_id}.{k}": v for k, v in result["outputs"].items()}
|
327 |
-
computed_values.update(outputs)
|
328 |
-
|
329 |
-
# Step 5: Gather and return workflow outputs.
|
330 |
-
final_outputs: dict[str, Any] = {}
|
331 |
-
for target, var in workflow.outputs.items():
|
332 |
-
if var not in computed_values:
|
333 |
-
raise WorkflowError(
|
334 |
-
f"Workflow output variable {var} was not produced. Computed values: {computed_values.keys()}"
|
335 |
-
)
|
336 |
-
final_outputs[target] = computed_values[var]
|
337 |
-
|
338 |
-
return WorkflowOutput(
|
339 |
-
final_outputs=final_outputs,
|
340 |
-
intermediate_outputs=computed_values,
|
341 |
-
step_contents=step_contents,
|
342 |
-
logprob=logprob,
|
343 |
-
)
|
344 |
-
|
345 |
-
|
346 |
-
def execute_simple_workflow(
|
347 |
-
workflow: Workflow,
|
348 |
-
input_values: dict[str, Any],
|
349 |
-
return_full_content: bool = False,
|
350 |
-
logprob_step: bool | str = False,
|
351 |
-
) -> WorkflowOutput:
|
352 |
-
"""
|
353 |
-
Execute a simple workflow with a single step.
|
354 |
-
|
355 |
-
This is an optimized version of workflow execution for workflows containing only one step.
|
356 |
-
It bypasses the dependency graph building and topological sorting steps, providing a more
|
357 |
-
direct execution path for simple workflows.
|
358 |
-
|
359 |
-
Args:
|
360 |
-
workflow (Workflow): The workflow to execute, which must contain exactly one step.
|
361 |
-
input_values (dict[str, Any]): External input values to be used by the workflow.
|
362 |
-
Keys should match the required workflow.inputs.
|
363 |
-
return_full_content (bool, optional): If True, includes the full model response content
|
364 |
-
in the result. Defaults to False.
|
365 |
-
logprobs (bool, optional): If True, calculates and returns log probability information
|
366 |
-
for the model response. Defaults to False.
|
367 |
-
|
368 |
-
Returns:
|
369 |
-
WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
|
370 |
-
optional step contents, and optional log probability information.
|
371 |
-
|
372 |
-
Raises:
|
373 |
-
WorkflowError: If the workflow has more than one step or if required inputs are missing.
|
374 |
-
|
375 |
-
Example:
|
376 |
-
>>> workflow = Workflow(
|
377 |
-
... steps={"extract": ModelStep(...)},
|
378 |
-
... inputs=["text"],
|
379 |
-
... outputs={"entities": "extract.entities"}
|
380 |
-
... )
|
381 |
-
>>> result = execute_simple_workflow(workflow, {"text": "Apple is launching a new product."})
|
382 |
-
>>> entities = result["final_outputs"]["entities"]
|
383 |
-
"""
|
384 |
-
if len(workflow.steps) != 1:
|
385 |
-
raise WorkflowError("Simple workflow must have exactly one step")
|
386 |
-
|
387 |
-
# Get the single step
|
388 |
-
step = list(workflow.steps.values())[0]
|
389 |
-
|
390 |
-
logprobs = logprob_step is True or logprob_step == step.id
|
391 |
-
|
392 |
-
# Validate inputs
|
393 |
-
for var in workflow.inputs:
|
394 |
-
if var not in input_values:
|
395 |
-
raise WorkflowError(f"Missing required workflow input: {var}")
|
396 |
-
|
397 |
-
# Execute the step
|
398 |
-
step_result = execute_model_step(step, input_values, return_full_content=return_full_content, logprobs=logprobs)
|
399 |
-
step_outputs = step_result["outputs"]
|
400 |
-
step_contents = {step.id: step_result["content"]} if return_full_content else {}
|
401 |
-
# Prepare the final outputs
|
402 |
-
final_outputs = {}
|
403 |
-
for target, var in workflow.outputs.items():
|
404 |
-
if var.startswith(f"{step.id}."):
|
405 |
-
output_key = var.split(".", 1)[1]
|
406 |
-
if output_key in step_outputs:
|
407 |
-
final_outputs[target] = step_outputs[output_key]
|
408 |
-
else:
|
409 |
-
raise WorkflowError(f"Workflow output variable {var} was not produced")
|
410 |
-
else:
|
411 |
-
raise WorkflowError(f"Invalid output mapping: {var} does not match step ID {step.id}")
|
412 |
-
|
413 |
-
# Prepare computed values (prefixed with step ID)
|
414 |
-
computed_values = input_values | {f"{step.id}.{k}": v for k, v in step_outputs.items()}
|
415 |
-
|
416 |
-
return WorkflowOutput(
|
417 |
-
final_outputs=final_outputs,
|
418 |
-
intermediate_outputs=computed_values,
|
419 |
-
step_contents=step_contents,
|
420 |
-
logprob=step_result.get("logprob"),
|
421 |
-
)
|
422 |
-
|
423 |
-
|
424 |
-
def execute_workflow(
|
425 |
-
workflow: Workflow,
|
426 |
-
input_values: dict[str, Any],
|
427 |
-
return_full_content: bool = False,
|
428 |
-
logprob_step: str | bool = False,
|
429 |
-
) -> WorkflowOutput:
|
430 |
-
"""
|
431 |
-
Main entry point for executing workflows of any complexity.
|
432 |
-
|
433 |
-
This function serves as a router that delegates to the appropriate specialized
|
434 |
-
execution function based on the complexity of the workflow:
|
435 |
-
- For single-step workflows, it calls execute_simple_workflow
|
436 |
-
- For multi-step workflows, it calls execute_multi_step_workflow
|
437 |
-
|
438 |
-
This abstraction allows callers to use a consistent interface regardless of
|
439 |
-
the workflow's complexity.
|
440 |
-
|
441 |
-
Args:
|
442 |
-
workflow (Workflow): The workflow to execute, containing steps, their
|
443 |
-
dependencies, and input/output specifications.
|
444 |
-
input_values (dict[str, Any]): External input values to be used by the workflow.
|
445 |
-
Keys should match the required workflow.inputs.
|
446 |
-
return_full_content (bool, optional): If True, includes the full model response
|
447 |
-
content in the result. Defaults to False.
|
448 |
-
logprob_step (str | bool, optional): Either a string with the ID of the step for which
|
449 |
-
to calculate log probability, or a boolean flag.
|
450 |
-
If False, no log probabilities are calculated.
|
451 |
-
Defaults to False.
|
452 |
-
|
453 |
-
Returns:
|
454 |
-
WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
|
455 |
-
optional step contents, and optional log probability information.
|
456 |
-
|
457 |
-
Raises:
|
458 |
-
WorkflowError: For any workflow-related errors, such as missing required inputs,
|
459 |
-
circular dependencies, or invalid variable references.
|
460 |
-
|
461 |
-
Example:
|
462 |
-
>>> workflow = Workflow(
|
463 |
-
... steps={"extract": ModelStep(...), "analyze": ModelStep(...)},
|
464 |
-
... inputs=["text"],
|
465 |
-
... outputs={"sentiment": "analyze.sentiment"}
|
466 |
-
... )
|
467 |
-
>>> result = execute_workflow(
|
468 |
-
... workflow,
|
469 |
-
... {"text": "Apple is launching a new product."},
|
470 |
-
... return_full_content=True,
|
471 |
-
... logprob_step="analyze"
|
472 |
-
... )
|
473 |
-
>>> print(result["final_outputs"]["sentiment"])
|
474 |
-
"positive"
|
475 |
-
"""
|
476 |
-
if len(workflow.steps) > 1:
|
477 |
-
return execute_multi_step_workflow(workflow, input_values, return_full_content, logprob_step)
|
478 |
-
else:
|
479 |
-
return execute_simple_workflow(workflow, input_values, return_full_content, logprob_step)
|
480 |
-
|
481 |
-
|
482 |
-
def run_examples():
|
483 |
-
"""
|
484 |
-
Runs example workflows demonstrating key functionality and error handling.
|
485 |
-
|
486 |
-
This function creates and executes three different example workflows to showcase:
|
487 |
-
|
488 |
-
1. Successful workflow execution:
|
489 |
-
- A linear two-step workflow with proper dependency flow
|
490 |
-
- Input transformation using the 'upper' function
|
491 |
-
- Output transformation using the 'lower' function
|
492 |
-
- Proper variable passing between steps
|
493 |
-
|
494 |
-
2. Cyclic dependency detection:
|
495 |
-
- A workflow with two steps that depend on each other circularly
|
496 |
-
- Demonstrates the error handling for cyclic dependencies
|
497 |
-
- Shows how the system prevents infinite execution loops
|
498 |
-
|
499 |
-
3. Unknown variable detection:
|
500 |
-
- A workflow that references a variable not provided as input or by any step
|
501 |
-
- Demonstrates validation of variable references
|
502 |
-
- Shows error handling for missing dependencies
|
503 |
-
|
504 |
-
Each example prints its result or the error encountered, making this function
|
505 |
-
useful for testing and demonstration purposes.
|
506 |
-
|
507 |
-
Returns:
|
508 |
-
None: This function prints its results and doesn't return a value.
|
509 |
-
"""
|
510 |
-
print("Example 1: Successful Workflow Execution")
|
511 |
-
# Example 1: Simple linear workflow.
|
512 |
-
# External input "input.value" is provided. Two steps:
|
513 |
-
# - step1 takes "input.value" and produces "step1.result".
|
514 |
-
# - step2 uses "step1.result" and produces "step2.final".
|
515 |
-
from workflows.structs import ModelStep, Workflow
|
516 |
-
|
517 |
-
workflow_success = Workflow(
|
518 |
-
steps={
|
519 |
-
"step1": ModelStep(
|
520 |
-
id="step1",
|
521 |
-
model="gpt-4o-mini",
|
522 |
-
provider="OpenAI",
|
523 |
-
call_type="llm",
|
524 |
-
system_prompt="Step1 processing",
|
525 |
-
input_fields=[InputField(name="value", description="Input value", variable="input.value")],
|
526 |
-
output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
|
527 |
-
),
|
528 |
-
"step2": ModelStep(
|
529 |
-
id="step2",
|
530 |
-
model="gpt-4o-mini",
|
531 |
-
provider="OpenAI",
|
532 |
-
call_type="llm",
|
533 |
-
system_prompt="Step2 processing",
|
534 |
-
input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
|
535 |
-
output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
|
536 |
-
),
|
537 |
-
},
|
538 |
-
inputs=["input.value"],
|
539 |
-
outputs={"final": "step2.final"},
|
540 |
-
)
|
541 |
-
input_values_success = {"input.value": "Hello, World!"}
|
542 |
-
try:
|
543 |
-
outputs = execute_workflow(workflow_success, input_values_success)
|
544 |
-
print("Workflow outputs:", outputs)
|
545 |
-
except WorkflowError as e:
|
546 |
-
print("Workflow failed with error:", e)
|
547 |
-
|
548 |
-
print("\nExample 2: Cyclic Dependency Workflow")
|
549 |
-
# Example 2: Cyclic dependency.
|
550 |
-
# stepA depends on an output from stepB and vice versa.
|
551 |
-
workflow_cycle = Workflow(
|
552 |
-
steps={
|
553 |
-
"stepA": ModelStep(
|
554 |
-
id="stepA",
|
555 |
-
model="gpt-4o-mini",
|
556 |
-
provider="OpenAI",
|
557 |
-
call_type="llm",
|
558 |
-
system_prompt="StepA processing",
|
559 |
-
input_fields=[
|
560 |
-
InputField(name="input", description="Input from stepB", variable="stepB.output", func="identity")
|
561 |
-
],
|
562 |
-
output_fields=[OutputField(name="output", description="Output from A", type="str", func="upper")],
|
563 |
-
),
|
564 |
-
"stepB": ModelStep(
|
565 |
-
id="stepB",
|
566 |
-
model="gpt-4o-mini",
|
567 |
-
provider="OpenAI",
|
568 |
-
call_type="llm",
|
569 |
-
system_prompt="StepB processing",
|
570 |
-
input_fields=[
|
571 |
-
InputField(name="input", description="Input from stepA", variable="stepA.output", func="identity")
|
572 |
-
],
|
573 |
-
output_fields=[OutputField(name="output", description="Output from B", type="str", func="upper")],
|
574 |
-
),
|
575 |
-
},
|
576 |
-
inputs=[], # no external inputs
|
577 |
-
outputs={"output": "stepB.output"},
|
578 |
-
)
|
579 |
-
try:
|
580 |
-
outputs = execute_workflow(workflow_cycle, {})
|
581 |
-
print("Workflow outputs:", outputs)
|
582 |
-
except WorkflowError as e:
|
583 |
-
print("Workflow failed with error:", e)
|
584 |
-
|
585 |
-
print("\nExample 3: Unknown Variable Dependency Workflow")
|
586 |
-
# Example 3: A workflow that references a variable not provided as an input or produced by any step.
|
587 |
-
workflow_unknown = Workflow(
|
588 |
-
steps={
|
589 |
-
"stepX": ModelStep(
|
590 |
-
id="stepX",
|
591 |
-
model="gpt-4o-mini",
|
592 |
-
provider="OpenAI",
|
593 |
-
call_type="llm",
|
594 |
-
system_prompt="StepX processing",
|
595 |
-
input_fields=[
|
596 |
-
InputField(
|
597 |
-
name="input", description="Non-existent input", variable="nonexistent.value", func="identity"
|
598 |
-
)
|
599 |
-
],
|
600 |
-
output_fields=[OutputField(name="output", description="Output from X", type="str", func="upper")],
|
601 |
-
)
|
602 |
-
},
|
603 |
-
inputs=[], # no external inputs
|
604 |
-
outputs={"output": "stepX.output"},
|
605 |
-
)
|
606 |
-
try:
|
607 |
-
outputs = execute_workflow(workflow_unknown, {})
|
608 |
-
print("Workflow outputs:", outputs)
|
609 |
-
except WorkflowError as e:
|
610 |
-
print("Workflow failed with error:", e)
|
611 |
-
|
612 |
-
|
613 |
-
if __name__ == "__main__":
|
614 |
-
# create example of model_step
|
615 |
-
model_step = ModelStep(
|
616 |
-
id="step1",
|
617 |
-
model="gpt-4o-mini",
|
618 |
-
provider="OpenAI",
|
619 |
-
call_type="llm",
|
620 |
-
system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
|
621 |
-
input_fields=[
|
622 |
-
InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
|
623 |
-
InputField(name="n", description="The number of entities to return", variable="n", func=None),
|
624 |
-
],
|
625 |
-
output_fields=[
|
626 |
-
OutputField(
|
627 |
-
name="entities",
|
628 |
-
description="The first N entities in the string as a list of strings",
|
629 |
-
type="list[str]",
|
630 |
-
func=None,
|
631 |
-
),
|
632 |
-
OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
|
633 |
-
],
|
634 |
-
)
|
635 |
-
|
636 |
-
processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
|
637 |
-
processed_inputs = create_processed_inputs(model_step, processed_inputs)
|
638 |
-
print(processed_inputs)
|
639 |
-
|
640 |
-
run_examples()
|
641 |
-
|
642 |
-
# %%
|
643 |
-
|
644 |
-
# Example usage
|
645 |
-
if __name__ == "__main__":
|
646 |
-
# Define a simple model step
|
647 |
-
model_step = ModelStep(
|
648 |
-
id="step1",
|
649 |
-
model="gpt-4o-mini",
|
650 |
-
provider="OpenAI",
|
651 |
-
call_type="llm",
|
652 |
-
system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
|
653 |
-
input_fields=[
|
654 |
-
InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
|
655 |
-
InputField(name="n", description="The number of entities to return", variable="n", func=None),
|
656 |
-
],
|
657 |
-
output_fields=[
|
658 |
-
OutputField(
|
659 |
-
name="entities",
|
660 |
-
description="The first N entities in the string as a list of strings",
|
661 |
-
type="list[str]",
|
662 |
-
func=None,
|
663 |
-
),
|
664 |
-
OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
|
665 |
-
],
|
666 |
-
)
|
667 |
-
|
668 |
-
# Define processed inputs
|
669 |
-
processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
|
670 |
-
|
671 |
-
# Execute the model step
|
672 |
-
outputs = execute_model_step(model_step, processed_inputs)
|
673 |
-
print(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/factory.py
DELETED
@@ -1,176 +0,0 @@
|
|
1 |
-
# %%
|
2 |
-
from .structs import (
|
3 |
-
Buzzer,
|
4 |
-
BuzzerMethod,
|
5 |
-
CallType,
|
6 |
-
InputField,
|
7 |
-
ModelStep,
|
8 |
-
OutputField,
|
9 |
-
TossupWorkflow,
|
10 |
-
Workflow,
|
11 |
-
)
|
12 |
-
|
13 |
-
INITIAL_SYS_PROMPT = """You are a helpful performant question answering bot.
|
14 |
-
Given a question clue, output your most likely guess in a couple words with a calibrated confidence for the guess.
|
15 |
-
"""
|
16 |
-
|
17 |
-
|
18 |
-
def create_empty_bonus_workflow():
|
19 |
-
return Workflow(
|
20 |
-
inputs=["leadin", "part"],
|
21 |
-
outputs={"answer": None, "confidence": None, "explanation": None},
|
22 |
-
steps={},
|
23 |
-
)
|
24 |
-
|
25 |
-
|
26 |
-
def create_empty_tossup_workflow():
|
27 |
-
return TossupWorkflow(
|
28 |
-
inputs=["question_text"],
|
29 |
-
outputs={"answer": None, "confidence": None},
|
30 |
-
steps={},
|
31 |
-
)
|
32 |
-
|
33 |
-
|
34 |
-
def create_first_step_input_fields() -> list[InputField]:
|
35 |
-
return [
|
36 |
-
InputField(
|
37 |
-
name="question",
|
38 |
-
description="The question text progressively revealed to the agent so far.",
|
39 |
-
variable="question_text",
|
40 |
-
)
|
41 |
-
]
|
42 |
-
|
43 |
-
|
44 |
-
def create_empty_input_field() -> list[InputField]:
|
45 |
-
return [InputField(name="", description="", variable="question_text")]
|
46 |
-
|
47 |
-
|
48 |
-
def create_quizbowl_simple_step_initial_setup():
|
49 |
-
return ModelStep(
|
50 |
-
id="simple_step",
|
51 |
-
name="Quizbowl Simple Step",
|
52 |
-
model="",
|
53 |
-
provider="",
|
54 |
-
temperature=0.7,
|
55 |
-
call_type="llm",
|
56 |
-
system_prompt=INITIAL_SYS_PROMPT,
|
57 |
-
input_fields=[
|
58 |
-
InputField(name="question", description="The question to answer", variable="question"),
|
59 |
-
],
|
60 |
-
output_fields=[
|
61 |
-
OutputField(name="answer", description="The most likely answer", type="str"),
|
62 |
-
OutputField(name="confidence", description="The confidence of the answer", type="float"),
|
63 |
-
],
|
64 |
-
)
|
65 |
-
|
66 |
-
|
67 |
-
def create_new_llm_step(step_id: str, name: str) -> ModelStep:
|
68 |
-
return ModelStep(
|
69 |
-
id=step_id,
|
70 |
-
name=name,
|
71 |
-
model="gpt-4o",
|
72 |
-
provider="OpenAI",
|
73 |
-
call_type="llm",
|
74 |
-
temperature=0.7,
|
75 |
-
system_prompt="",
|
76 |
-
input_fields=create_empty_input_field(),
|
77 |
-
output_fields=[OutputField(name="", description="")],
|
78 |
-
)
|
79 |
-
|
80 |
-
|
81 |
-
def create_first_llm_step() -> ModelStep:
|
82 |
-
return ModelStep(
|
83 |
-
id="A",
|
84 |
-
name="",
|
85 |
-
model="gpt-4o",
|
86 |
-
provider="OpenAI",
|
87 |
-
call_type="llm",
|
88 |
-
temperature=0.7,
|
89 |
-
system_prompt="",
|
90 |
-
input_fields=[create_first_step_input_fields()],
|
91 |
-
output_fields=[OutputField(name="", description="")],
|
92 |
-
)
|
93 |
-
|
94 |
-
|
95 |
-
def create_simple_qb_tossup_workflow():
|
96 |
-
return TossupWorkflow(
|
97 |
-
inputs=["question_text"],
|
98 |
-
outputs={"answer": "A.answer", "confidence": "A.confidence"},
|
99 |
-
steps={
|
100 |
-
"A": ModelStep(
|
101 |
-
id="A",
|
102 |
-
name="Tossup Agent",
|
103 |
-
model="gpt-4o-mini",
|
104 |
-
provider="OpenAI",
|
105 |
-
call_type="llm",
|
106 |
-
temperature=0.3,
|
107 |
-
system_prompt="You are a helpful assistant that can answer questions.",
|
108 |
-
input_fields=[InputField(name="question", description="The question text", variable="question_text")],
|
109 |
-
output_fields=[
|
110 |
-
OutputField(
|
111 |
-
name="answer",
|
112 |
-
description="The best guess at the answer to the question",
|
113 |
-
type="str",
|
114 |
-
),
|
115 |
-
OutputField(
|
116 |
-
name="confidence",
|
117 |
-
description="The confidence in the answer, ranging from 0 to 1 in increments of 0.05.",
|
118 |
-
type="float",
|
119 |
-
),
|
120 |
-
],
|
121 |
-
)
|
122 |
-
},
|
123 |
-
buzzer=Buzzer(
|
124 |
-
confidence_threshold=0.75,
|
125 |
-
prob_threshold=None,
|
126 |
-
method=BuzzerMethod.AND,
|
127 |
-
),
|
128 |
-
)
|
129 |
-
|
130 |
-
|
131 |
-
BONUS_SYS_PROMPT = """You are a quizbowl player answering bonus questions. For each part:
|
132 |
-
1. Read the leadin and part carefully
|
133 |
-
2. Provide a concise answer
|
134 |
-
3. Rate your confidence (0-1)
|
135 |
-
4. Explain your reasoning
|
136 |
-
|
137 |
-
Format your response as:
|
138 |
-
ANSWER: <your answer>
|
139 |
-
CONFIDENCE: <0-1>
|
140 |
-
EXPLANATION: <your reasoning>"""
|
141 |
-
|
142 |
-
|
143 |
-
def create_simple_qb_bonus_workflow() -> Workflow:
|
144 |
-
"""Create a simple model step for bonus questions."""
|
145 |
-
return Workflow(
|
146 |
-
inputs=["leadin", "part"],
|
147 |
-
outputs={"answer": "A.answer", "confidence": "A.confidence", "explanation": "A.explanation"},
|
148 |
-
steps={
|
149 |
-
"A": ModelStep(
|
150 |
-
id="A",
|
151 |
-
name="Bonus Agent",
|
152 |
-
model="gpt-4o-mini",
|
153 |
-
provider="OpenAI",
|
154 |
-
temperature=0.3,
|
155 |
-
call_type=CallType.LLM,
|
156 |
-
system_prompt=BONUS_SYS_PROMPT,
|
157 |
-
input_fields=[
|
158 |
-
InputField(
|
159 |
-
name="question_leadin",
|
160 |
-
description="The leadin text for the bonus question",
|
161 |
-
variable="leadin",
|
162 |
-
),
|
163 |
-
InputField(
|
164 |
-
name="question_part",
|
165 |
-
description="The specific part text to answer",
|
166 |
-
variable="part",
|
167 |
-
),
|
168 |
-
],
|
169 |
-
output_fields=[
|
170 |
-
OutputField(name="answer", description="The predicted answer", type="str"),
|
171 |
-
OutputField(name="confidence", description="Confidence in the answer (0-1)", type="float"),
|
172 |
-
OutputField(name="explanation", description="Short explanation for the answer", type="str"),
|
173 |
-
],
|
174 |
-
)
|
175 |
-
},
|
176 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/llmcache.py
DELETED
@@ -1,488 +0,0 @@
|
|
1 |
-
import hashlib
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
import sqlite3
|
5 |
-
import threading
|
6 |
-
import time
|
7 |
-
from pathlib import Path
|
8 |
-
from typing import Any, Optional
|
9 |
-
|
10 |
-
from datasets import Dataset, load_dataset, load_from_disk
|
11 |
-
from huggingface_hub import snapshot_download
|
12 |
-
from loguru import logger
|
13 |
-
|
14 |
-
|
15 |
-
def load_dataset_from_hf(repo_id, local_dir):
|
16 |
-
snapshot_download(
|
17 |
-
repo_id=repo_id,
|
18 |
-
local_dir=local_dir,
|
19 |
-
repo_type="dataset",
|
20 |
-
tqdm_class=None,
|
21 |
-
etag_timeout=30,
|
22 |
-
token=os.environ["HF_TOKEN"],
|
23 |
-
)
|
24 |
-
return load_dataset(repo_id)
|
25 |
-
|
26 |
-
|
27 |
-
class CacheDB:
|
28 |
-
"""Handles database operations for storing and retrieving cache entries."""
|
29 |
-
|
30 |
-
def __init__(self, db_path: Path):
|
31 |
-
"""Initialize database connection.
|
32 |
-
|
33 |
-
Args:
|
34 |
-
db_path: Path to SQLite database file
|
35 |
-
"""
|
36 |
-
self.db_path = db_path
|
37 |
-
self.lock = threading.Lock()
|
38 |
-
|
39 |
-
# Initialize the database
|
40 |
-
try:
|
41 |
-
self.initialize_db()
|
42 |
-
except Exception as e:
|
43 |
-
logger.exception(f"Failed to initialize database: {e}")
|
44 |
-
logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
|
45 |
-
raise
|
46 |
-
|
47 |
-
def initialize_db(self) -> None:
|
48 |
-
"""Initialize SQLite database with the required table."""
|
49 |
-
# Check if database file already exists
|
50 |
-
if self.db_path.exists():
|
51 |
-
self._verify_existing_db()
|
52 |
-
else:
|
53 |
-
self._create_new_db()
|
54 |
-
|
55 |
-
def _verify_existing_db(self) -> None:
|
56 |
-
"""Verify and repair an existing database if needed."""
|
57 |
-
try:
|
58 |
-
with sqlite3.connect(self.db_path) as conn:
|
59 |
-
cursor = conn.cursor()
|
60 |
-
self._ensure_table_exists(cursor)
|
61 |
-
self._verify_schema(cursor)
|
62 |
-
self._ensure_index_exists(cursor)
|
63 |
-
conn.commit()
|
64 |
-
logger.info(f"Using existing SQLite database at {self.db_path}")
|
65 |
-
except Exception as e:
|
66 |
-
logger.exception(f"Database corruption detected: {e}")
|
67 |
-
raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")
|
68 |
-
|
69 |
-
def _create_new_db(self) -> None:
|
70 |
-
"""Create a new database with the required schema."""
|
71 |
-
try:
|
72 |
-
with sqlite3.connect(self.db_path) as conn:
|
73 |
-
cursor = conn.cursor()
|
74 |
-
self._create_table(cursor)
|
75 |
-
self._ensure_index_exists(cursor)
|
76 |
-
conn.commit()
|
77 |
-
logger.info(f"Initialized new SQLite database at {self.db_path}")
|
78 |
-
except Exception as e:
|
79 |
-
logger.exception(f"Failed to initialize SQLite database: {e}")
|
80 |
-
raise
|
81 |
-
|
82 |
-
def _ensure_table_exists(self, cursor) -> None:
|
83 |
-
"""Check if the llm_cache table exists and create it if not."""
|
84 |
-
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
|
85 |
-
if not cursor.fetchone():
|
86 |
-
self._create_table(cursor)
|
87 |
-
logger.info("Created missing llm_cache table")
|
88 |
-
|
89 |
-
def _create_table(self, cursor) -> None:
|
90 |
-
"""Create the llm_cache table with the required schema."""
|
91 |
-
cursor.execute("""
|
92 |
-
CREATE TABLE IF NOT EXISTS llm_cache (
|
93 |
-
key TEXT PRIMARY KEY,
|
94 |
-
request_json TEXT,
|
95 |
-
response_json TEXT,
|
96 |
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
97 |
-
)
|
98 |
-
""")
|
99 |
-
|
100 |
-
def _verify_schema(self, cursor) -> None:
|
101 |
-
"""Verify that the table schema has all required columns."""
|
102 |
-
cursor.execute("PRAGMA table_info(llm_cache)")
|
103 |
-
columns = {row[1] for row in cursor.fetchall()}
|
104 |
-
required_columns = {"key", "request_json", "response_json", "created_at"}
|
105 |
-
|
106 |
-
if not required_columns.issubset(columns):
|
107 |
-
missing = required_columns - columns
|
108 |
-
raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")
|
109 |
-
|
110 |
-
def _ensure_index_exists(self, cursor) -> None:
|
111 |
-
"""Create an index on the key column for faster lookups."""
|
112 |
-
cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")
|
113 |
-
|
114 |
-
def get(self, key: str) -> Optional[dict[str, Any]]:
|
115 |
-
"""Get cached entry by key.
|
116 |
-
|
117 |
-
Args:
|
118 |
-
key: Cache key to look up
|
119 |
-
|
120 |
-
Returns:
|
121 |
-
Dict containing the request and response or None if not found
|
122 |
-
"""
|
123 |
-
try:
|
124 |
-
with sqlite3.connect(self.db_path) as conn:
|
125 |
-
conn.row_factory = sqlite3.Row
|
126 |
-
cursor = conn.cursor()
|
127 |
-
cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
|
128 |
-
result = cursor.fetchone()
|
129 |
-
|
130 |
-
if result:
|
131 |
-
logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
|
132 |
-
return {
|
133 |
-
"request": result["request_json"],
|
134 |
-
"response": result["response_json"],
|
135 |
-
}
|
136 |
-
|
137 |
-
logger.debug(f"Cache miss for key: {key}")
|
138 |
-
return None
|
139 |
-
except Exception as e:
|
140 |
-
logger.error(f"Error retrieving from cache: {e}")
|
141 |
-
return None
|
142 |
-
|
143 |
-
def set(self, key: str, request_json: str, response_json: str) -> bool:
|
144 |
-
"""Set entry in cache.
|
145 |
-
|
146 |
-
Args:
|
147 |
-
key: Cache key
|
148 |
-
request_json: JSON string of request parameters
|
149 |
-
response_json: JSON string of response
|
150 |
-
|
151 |
-
Returns:
|
152 |
-
True if successful, False otherwise
|
153 |
-
"""
|
154 |
-
with self.lock:
|
155 |
-
try:
|
156 |
-
with sqlite3.connect(self.db_path) as conn:
|
157 |
-
cursor = conn.cursor()
|
158 |
-
cursor.execute(
|
159 |
-
"INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
|
160 |
-
(key, request_json, response_json),
|
161 |
-
)
|
162 |
-
conn.commit()
|
163 |
-
logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
|
164 |
-
return True
|
165 |
-
except Exception as e:
|
166 |
-
logger.error(f"Failed to save to SQLite cache: {e}")
|
167 |
-
return False
|
168 |
-
|
169 |
-
def get_all_entries(self) -> dict[str, dict[str, Any]]:
|
170 |
-
"""Get all cache entries from the database."""
|
171 |
-
cache = {}
|
172 |
-
try:
|
173 |
-
with sqlite3.connect(self.db_path) as conn:
|
174 |
-
conn.row_factory = sqlite3.Row
|
175 |
-
cursor = conn.cursor()
|
176 |
-
cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")
|
177 |
-
|
178 |
-
for row in cursor.fetchall():
|
179 |
-
cache[row["key"]] = {
|
180 |
-
"request": row["request_json"],
|
181 |
-
"response": row["response_json"],
|
182 |
-
}
|
183 |
-
|
184 |
-
logger.debug(f"Retrieved {len(cache)} entries from cache database")
|
185 |
-
return cache
|
186 |
-
except Exception as e:
|
187 |
-
logger.error(f"Error retrieving all cache entries: {e}")
|
188 |
-
return {}
|
189 |
-
|
190 |
-
def clear(self) -> bool:
|
191 |
-
"""Clear all cache entries.
|
192 |
-
|
193 |
-
Returns:
|
194 |
-
True if successful, False otherwise
|
195 |
-
"""
|
196 |
-
with self.lock:
|
197 |
-
try:
|
198 |
-
with sqlite3.connect(self.db_path) as conn:
|
199 |
-
cursor = conn.cursor()
|
200 |
-
cursor.execute("DELETE FROM llm_cache")
|
201 |
-
conn.commit()
|
202 |
-
logger.info("Cache cleared")
|
203 |
-
return True
|
204 |
-
except Exception as e:
|
205 |
-
logger.error(f"Failed to clear cache: {e}")
|
206 |
-
return False
|
207 |
-
|
208 |
-
def get_existing_keys(self) -> set:
|
209 |
-
"""Get all existing keys in the database.
|
210 |
-
|
211 |
-
Returns:
|
212 |
-
Set of keys
|
213 |
-
"""
|
214 |
-
existing_keys = set()
|
215 |
-
try:
|
216 |
-
with sqlite3.connect(self.db_path) as conn:
|
217 |
-
cursor = conn.cursor()
|
218 |
-
cursor.execute("SELECT key FROM llm_cache")
|
219 |
-
for row in cursor.fetchall():
|
220 |
-
existing_keys.add(row[0])
|
221 |
-
return existing_keys
|
222 |
-
except Exception as e:
|
223 |
-
logger.error(f"Error retrieving existing keys: {e}")
|
224 |
-
return set()
|
225 |
-
|
226 |
-
def bulk_insert(self, items: list, update: bool = False) -> int:
|
227 |
-
"""Insert multiple items into the cache.
|
228 |
-
|
229 |
-
Args:
|
230 |
-
items: List of (key, request_json, response_json) tuples
|
231 |
-
update: Whether to update existing entries
|
232 |
-
|
233 |
-
Returns:
|
234 |
-
Number of items inserted
|
235 |
-
"""
|
236 |
-
count = 0
|
237 |
-
UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
|
238 |
-
with self.lock:
|
239 |
-
try:
|
240 |
-
with sqlite3.connect(self.db_path) as conn:
|
241 |
-
cursor = conn.cursor()
|
242 |
-
cursor.executemany(
|
243 |
-
f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
|
244 |
-
items,
|
245 |
-
)
|
246 |
-
count = cursor.rowcount
|
247 |
-
conn.commit()
|
248 |
-
return count
|
249 |
-
except Exception as e:
|
250 |
-
logger.error(f"Error during bulk insert: {e}")
|
251 |
-
return 0
|
252 |
-
|
253 |
-
|
254 |
-
class LLMCache:
|
255 |
-
def __init__(
|
256 |
-
self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
|
257 |
-
):
|
258 |
-
self.cache_dir = Path(cache_dir)
|
259 |
-
self.db_path = self.cache_dir / "llm_cache.db"
|
260 |
-
self.hf_repo_id = hf_repo
|
261 |
-
self.cache_sync_interval = cache_sync_interval
|
262 |
-
self.last_sync_time = time.time()
|
263 |
-
|
264 |
-
# Create cache directory if it doesn't exist
|
265 |
-
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
266 |
-
|
267 |
-
# Initialize CacheDB
|
268 |
-
self.db = CacheDB(self.db_path)
|
269 |
-
if reset:
|
270 |
-
self.db.clear()
|
271 |
-
|
272 |
-
# Try to load from HF dataset if available
|
273 |
-
try:
|
274 |
-
self._load_cache_from_hf()
|
275 |
-
except Exception as e:
|
276 |
-
logger.warning(f"Failed to load cache from HF dataset: {e}")
|
277 |
-
|
278 |
-
def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
|
279 |
-
"""Convert a response format to a dict."""
|
280 |
-
# If it's a Pydantic model, use its schema
|
281 |
-
if hasattr(response_format, "model_json_schema"):
|
282 |
-
response_format = response_format.model_json_schema()
|
283 |
-
|
284 |
-
# If it's a Pydantic model, use its dump
|
285 |
-
elif hasattr(response_format, "model_dump"):
|
286 |
-
response_format = response_format.model_dump()
|
287 |
-
|
288 |
-
if not isinstance(response_format, dict):
|
289 |
-
response_format = {"value": str(response_format)}
|
290 |
-
|
291 |
-
return response_format
|
292 |
-
|
293 |
-
def _generate_key(
|
294 |
-
self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
|
295 |
-
) -> str:
|
296 |
-
"""Generate a unique key for caching based on inputs."""
|
297 |
-
response_format_dict = self.response_format_to_dict(response_format)
|
298 |
-
response_format_str = json.dumps(response_format_dict, sort_keys=True)
|
299 |
-
# Include temperature in the key
|
300 |
-
key_content = f"{model}:{system}:{prompt}:{response_format_str}"
|
301 |
-
if temperature is not None:
|
302 |
-
key_content += f":{temperature:.2f}"
|
303 |
-
return hashlib.md5(key_content.encode()).hexdigest()
|
304 |
-
|
305 |
-
def _create_request_json(
|
306 |
-
self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
|
307 |
-
) -> str:
|
308 |
-
"""Create JSON string from request parameters."""
|
309 |
-
logger.info(f"Creating request JSON with temperature: {temperature}")
|
310 |
-
request_data = {
|
311 |
-
"model": model,
|
312 |
-
"system": system,
|
313 |
-
"prompt": prompt,
|
314 |
-
"response_format": self.response_format_to_dict(response_format),
|
315 |
-
"temperature": temperature,
|
316 |
-
}
|
317 |
-
return json.dumps(request_data)
|
318 |
-
|
319 |
-
def _check_request_match(
|
320 |
-
self,
|
321 |
-
cached_request: dict[str, Any],
|
322 |
-
model: str,
|
323 |
-
system: str,
|
324 |
-
prompt: str,
|
325 |
-
response_format: Any,
|
326 |
-
temperature: float | None,
|
327 |
-
) -> bool:
|
328 |
-
"""Check if the cached request matches the new request."""
|
329 |
-
# Check each field and log any mismatches
|
330 |
-
if cached_request["model"] != model:
|
331 |
-
logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
|
332 |
-
return False
|
333 |
-
if cached_request["system"] != system:
|
334 |
-
logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
|
335 |
-
return False
|
336 |
-
if cached_request["prompt"] != prompt:
|
337 |
-
logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
|
338 |
-
return False
|
339 |
-
response_format_dict = self.response_format_to_dict(response_format)
|
340 |
-
if cached_request["response_format"] != response_format_dict:
|
341 |
-
logger.debug(
|
342 |
-
f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
|
343 |
-
)
|
344 |
-
return False
|
345 |
-
if cached_request["temperature"] != temperature:
|
346 |
-
logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
|
347 |
-
return False
|
348 |
-
|
349 |
-
return True
|
350 |
-
|
351 |
-
def get(
|
352 |
-
self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
|
353 |
-
) -> Optional[dict[str, Any]]:
|
354 |
-
"""Get cached response if it exists."""
|
355 |
-
key = self._generate_key(model, system, prompt, response_format, temperature)
|
356 |
-
result = self.db.get(key)
|
357 |
-
|
358 |
-
if not result:
|
359 |
-
return None
|
360 |
-
request_dict = json.loads(result["request"])
|
361 |
-
if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
|
362 |
-
logger.warning(f"Cached request does not match new request for key: {key}")
|
363 |
-
return None
|
364 |
-
|
365 |
-
return json.loads(result["response"])
|
366 |
-
|
367 |
-
def set(
|
368 |
-
self,
|
369 |
-
model: str,
|
370 |
-
system: str,
|
371 |
-
prompt: str,
|
372 |
-
response_format: dict[str, Any],
|
373 |
-
temperature: float | None,
|
374 |
-
response: dict[str, Any],
|
375 |
-
) -> None:
|
376 |
-
"""Set response in cache and sync if needed."""
|
377 |
-
key = self._generate_key(model, system, prompt, response_format, temperature)
|
378 |
-
request_json = self._create_request_json(model, system, prompt, response_format, temperature)
|
379 |
-
response_json = json.dumps(response)
|
380 |
-
|
381 |
-
success = self.db.set(key, request_json, response_json)
|
382 |
-
|
383 |
-
# Check if we should sync to HF
|
384 |
-
if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
|
385 |
-
try:
|
386 |
-
self.sync_to_hf()
|
387 |
-
self.last_sync_time = time.time()
|
388 |
-
except Exception as e:
|
389 |
-
logger.error(f"Failed to sync cache to HF dataset: {e}")
|
390 |
-
|
391 |
-
def _load_cache_from_hf(self) -> None:
|
392 |
-
"""Load cache from HF dataset if it exists and merge with local cache."""
|
393 |
-
if not self.hf_repo_id:
|
394 |
-
return
|
395 |
-
|
396 |
-
try:
|
397 |
-
# Check for new commits before loading the dataset
|
398 |
-
ds_path = (self.cache_dir / "hf_cache").as_posix()
|
399 |
-
dataset = load_dataset_from_hf(self.hf_repo_id, ds_path)["train"]
|
400 |
-
if not dataset:
|
401 |
-
logger.info("No new items to merge from HF dataset")
|
402 |
-
return
|
403 |
-
|
404 |
-
existing_keys = self.db.get_existing_keys()
|
405 |
-
|
406 |
-
logger.info(f"Found {len(dataset)} items in HF dataset. Existing keys: {len(existing_keys)}")
|
407 |
-
|
408 |
-
# Prepare batch items for insertion
|
409 |
-
items_to_insert = []
|
410 |
-
for item in dataset:
|
411 |
-
key = item["key"]
|
412 |
-
# Only update if not in local cache to prioritize local changes
|
413 |
-
if key in existing_keys:
|
414 |
-
continue
|
415 |
-
# Create request JSON
|
416 |
-
request_data = {
|
417 |
-
"model": item["model"],
|
418 |
-
"system": item["system"],
|
419 |
-
"prompt": item["prompt"],
|
420 |
-
"temperature": item["temperature"],
|
421 |
-
"response_format": None, # We can't fully reconstruct this
|
422 |
-
}
|
423 |
-
|
424 |
-
items_to_insert.append(
|
425 |
-
(
|
426 |
-
key,
|
427 |
-
json.dumps(request_data),
|
428 |
-
item["response"], # This is already a JSON string
|
429 |
-
)
|
430 |
-
)
|
431 |
-
logger.info(
|
432 |
-
f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
|
433 |
-
)
|
434 |
-
|
435 |
-
# Bulk insert new items
|
436 |
-
if items_to_insert:
|
437 |
-
inserted_count = self.db.bulk_insert(items_to_insert)
|
438 |
-
logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
|
439 |
-
else:
|
440 |
-
logger.info("No new items to merge from HF dataset")
|
441 |
-
except Exception as e:
|
442 |
-
logger.warning(f"Could not load cache from HF dataset: {e}")
|
443 |
-
|
444 |
-
def get_all_entries(self) -> dict[str, dict[str, Any]]:
|
445 |
-
"""Get all cache entries from the database."""
|
446 |
-
cache = self.db.get_all_entries()
|
447 |
-
entries = {}
|
448 |
-
for key, entry in cache.items():
|
449 |
-
request = json.loads(entry["request"])
|
450 |
-
response = json.loads(entry["response"])
|
451 |
-
entries[key] = {"request": request, "response": response}
|
452 |
-
return entries
|
453 |
-
|
454 |
-
def sync_to_hf(self) -> None:
|
455 |
-
"""Sync cache to HF dataset."""
|
456 |
-
if not self.hf_repo_id:
|
457 |
-
return
|
458 |
-
|
459 |
-
self._load_cache_from_hf()
|
460 |
-
|
461 |
-
# Get all entries from the database
|
462 |
-
cache = self.db.get_all_entries()
|
463 |
-
|
464 |
-
# Convert cache to dataset format
|
465 |
-
entries = []
|
466 |
-
for key, entry in cache.items():
|
467 |
-
request = json.loads(entry["request"])
|
468 |
-
response_str = entry["response"]
|
469 |
-
entries.append(
|
470 |
-
{
|
471 |
-
"key": key,
|
472 |
-
"model": request["model"],
|
473 |
-
"system": request["system"],
|
474 |
-
"prompt": request["prompt"],
|
475 |
-
"response_format": request["response_format"],
|
476 |
-
"temperature": request["temperature"],
|
477 |
-
"response": response_str,
|
478 |
-
}
|
479 |
-
)
|
480 |
-
|
481 |
-
# Create and push dataset
|
482 |
-
dataset = Dataset.from_list(entries)
|
483 |
-
dataset.push_to_hub(self.hf_repo_id, private=True)
|
484 |
-
logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")
|
485 |
-
|
486 |
-
def clear(self) -> None:
|
487 |
-
"""Clear all cache entries."""
|
488 |
-
self.db.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/llms.py
DELETED
@@ -1,285 +0,0 @@
|
|
1 |
-
# %%
|
2 |
-
|
3 |
-
import json
|
4 |
-
import os
|
5 |
-
from typing import Any, Optional
|
6 |
-
|
7 |
-
import cohere
|
8 |
-
import numpy as np
|
9 |
-
from langchain_anthropic import ChatAnthropic
|
10 |
-
from langchain_cohere import ChatCohere
|
11 |
-
from langchain_core.language_models import BaseChatModel
|
12 |
-
from langchain_openai import ChatOpenAI
|
13 |
-
from loguru import logger
|
14 |
-
from openai import OpenAI
|
15 |
-
from pydantic import BaseModel, Field
|
16 |
-
from pydantic._internal._core_utils import CoreSchemaOrField, is_core_schema
|
17 |
-
from pydantic.json_schema import GenerateJsonSchema
|
18 |
-
from rich import print as rprint
|
19 |
-
|
20 |
-
# Initialize global cache
|
21 |
-
from src.envs import CACHE_PATH, LLM_CACHE_REPO
|
22 |
-
|
23 |
-
from .configs import AVAILABLE_MODELS
|
24 |
-
from .llmcache import LLMCache
|
25 |
-
|
26 |
-
llm_cache = LLMCache(cache_dir=CACHE_PATH, hf_repo=LLM_CACHE_REPO)
|
27 |
-
|
28 |
-
|
29 |
-
class CohereSchemaGenerator(GenerateJsonSchema):
|
30 |
-
"""Generates JSON schema for Cohere models without default titles."""
|
31 |
-
|
32 |
-
def field_title_should_be_set(self, schema: CoreSchemaOrField) -> bool:
|
33 |
-
return_value = super().field_title_should_be_set(schema)
|
34 |
-
if return_value and is_core_schema(schema):
|
35 |
-
return False
|
36 |
-
return return_value
|
37 |
-
|
38 |
-
|
39 |
-
def _openai_is_json_mode_supported(model_name: str) -> bool:
|
40 |
-
if model_name.startswith("gpt-4"):
|
41 |
-
return True
|
42 |
-
if model_name.startswith("gpt-3.5"):
|
43 |
-
return False
|
44 |
-
logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
|
45 |
-
return False
|
46 |
-
|
47 |
-
|
48 |
-
class LLMOutput(BaseModel):
|
49 |
-
content: str = Field(description="The content of the response")
|
50 |
-
logprob: Optional[float] = Field(None, description="The log probability of the response")
|
51 |
-
|
52 |
-
|
53 |
-
def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> str:
|
54 |
-
output = llm.invoke([("system", system), ("human", prompt)])
|
55 |
-
ai_message = output["raw"]
|
56 |
-
content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
|
57 |
-
content_str = json.dumps(content)
|
58 |
-
return {"content": content_str, "output": output["parsed"].model_dump()}
|
59 |
-
|
60 |
-
|
61 |
-
def _cohere_completion(
|
62 |
-
model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
|
63 |
-
) -> str:
|
64 |
-
messages = [
|
65 |
-
{"role": "system", "content": system},
|
66 |
-
{"role": "user", "content": prompt},
|
67 |
-
]
|
68 |
-
client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
|
69 |
-
schema = response_model.model_json_schema(schema_generator=CohereSchemaGenerator)
|
70 |
-
if "title" in schema:
|
71 |
-
del schema["title"]
|
72 |
-
response_format = {
|
73 |
-
"type": "json_object",
|
74 |
-
"schema": schema,
|
75 |
-
}
|
76 |
-
response = client.chat(
|
77 |
-
model=model,
|
78 |
-
messages=messages,
|
79 |
-
response_format=response_format,
|
80 |
-
logprobs=logprobs,
|
81 |
-
temperature=temperature,
|
82 |
-
)
|
83 |
-
output = {}
|
84 |
-
output["content"] = response.message.content[0].text
|
85 |
-
output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
|
86 |
-
if logprobs:
|
87 |
-
output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
|
88 |
-
output["prob"] = np.exp(output["logprob"])
|
89 |
-
return output
|
90 |
-
|
91 |
-
|
92 |
-
def _openai_langchain_completion(
|
93 |
-
model: str, system: str, prompt: str, response_model, temperature: float | None = None
|
94 |
-
) -> str:
|
95 |
-
llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
|
96 |
-
return _get_langchain_chat_output(llm, system, prompt)
|
97 |
-
|
98 |
-
|
99 |
-
def _openai_completion(
|
100 |
-
model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
|
101 |
-
) -> str:
|
102 |
-
messages = [
|
103 |
-
{"role": "system", "content": system},
|
104 |
-
{"role": "user", "content": prompt},
|
105 |
-
]
|
106 |
-
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
107 |
-
response = client.beta.chat.completions.parse(
|
108 |
-
model=model,
|
109 |
-
messages=messages,
|
110 |
-
response_format=response_model,
|
111 |
-
logprobs=logprobs,
|
112 |
-
temperature=temperature,
|
113 |
-
)
|
114 |
-
output = {}
|
115 |
-
output["content"] = response.choices[0].message.content
|
116 |
-
output["output"] = response.choices[0].message.parsed.model_dump()
|
117 |
-
if logprobs:
|
118 |
-
output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
|
119 |
-
output["prob"] = np.exp(output["logprob"])
|
120 |
-
return output
|
121 |
-
|
122 |
-
|
123 |
-
def _anthropic_completion(
|
124 |
-
model: str, system: str, prompt: str, response_model, temperature: float | None = None
|
125 |
-
) -> str:
|
126 |
-
llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
|
127 |
-
return _get_langchain_chat_output(llm, system, prompt)
|
128 |
-
|
129 |
-
|
130 |
-
def _llm_completion(
|
131 |
-
model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
|
132 |
-
) -> dict[str, Any]:
|
133 |
-
"""
|
134 |
-
Generate a completion from an LLM provider with structured output without caching.
|
135 |
-
|
136 |
-
Args:
|
137 |
-
model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
|
138 |
-
system (str): System prompt/instructions for the model
|
139 |
-
prompt (str): User prompt/input
|
140 |
-
response_format: Pydantic model defining the expected response structure
|
141 |
-
logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
|
142 |
-
Note: Not supported by Anthropic models.
|
143 |
-
|
144 |
-
Returns:
|
145 |
-
dict: Contains:
|
146 |
-
- output: The structured response matching response_format
|
147 |
-
- logprob: (optional) Sum of log probabilities if logprobs=True
|
148 |
-
- prob: (optional) Exponential of logprob if logprobs=True
|
149 |
-
|
150 |
-
Raises:
|
151 |
-
ValueError: If logprobs=True with Anthropic models
|
152 |
-
"""
|
153 |
-
model_name = AVAILABLE_MODELS[model]["model"]
|
154 |
-
provider = model.split("/")[0]
|
155 |
-
if provider == "Cohere":
|
156 |
-
return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
|
157 |
-
elif provider == "OpenAI":
|
158 |
-
if _openai_is_json_mode_supported(model_name):
|
159 |
-
return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
|
160 |
-
elif logprobs:
|
161 |
-
raise ValueError(f"{model} does not support logprobs feature.")
|
162 |
-
else:
|
163 |
-
return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
|
164 |
-
elif provider == "Anthropic":
|
165 |
-
if logprobs:
|
166 |
-
raise ValueError("Anthropic models do not support logprobs")
|
167 |
-
return _anthropic_completion(model_name, system, prompt, response_format, temperature)
|
168 |
-
else:
|
169 |
-
raise ValueError(f"Provider {provider} not supported")
|
170 |
-
|
171 |
-
|
172 |
-
def completion(
|
173 |
-
model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
|
174 |
-
) -> dict[str, Any]:
|
175 |
-
"""
|
176 |
-
Generate a completion from an LLM provider with structured output with caching.
|
177 |
-
|
178 |
-
Args:
|
179 |
-
model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
|
180 |
-
system (str): System prompt/instructions for the model
|
181 |
-
prompt (str): User prompt/input
|
182 |
-
response_format: Pydantic model defining the expected response structure
|
183 |
-
logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
|
184 |
-
Note: Not supported by Anthropic models.
|
185 |
-
|
186 |
-
Returns:
|
187 |
-
dict: Contains:
|
188 |
-
- output: The structured response matching response_format
|
189 |
-
- logprob: (optional) Sum of log probabilities if logprobs=True
|
190 |
-
- prob: (optional) Exponential of logprob if logprobs=True
|
191 |
-
|
192 |
-
Raises:
|
193 |
-
ValueError: If logprobs=True with Anthropic models
|
194 |
-
"""
|
195 |
-
if model not in AVAILABLE_MODELS:
|
196 |
-
raise ValueError(f"Model {model} not supported")
|
197 |
-
if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
|
198 |
-
logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
|
199 |
-
logprobs = False
|
200 |
-
|
201 |
-
# Check cache first
|
202 |
-
cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
|
203 |
-
if cached_response and (not logprobs or cached_response.get("logprob")):
|
204 |
-
logger.debug(f"Cache hit for model {model}")
|
205 |
-
return cached_response
|
206 |
-
|
207 |
-
logger.debug(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")
|
208 |
-
|
209 |
-
# Continue with the original implementation for cache miss
|
210 |
-
response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
|
211 |
-
|
212 |
-
# Update cache with the new response
|
213 |
-
llm_cache.set(
|
214 |
-
model,
|
215 |
-
system,
|
216 |
-
prompt,
|
217 |
-
response_format,
|
218 |
-
temperature,
|
219 |
-
response,
|
220 |
-
)
|
221 |
-
|
222 |
-
return response
|
223 |
-
|
224 |
-
|
225 |
-
# %%
|
226 |
-
if __name__ == "__main__":
|
227 |
-
from tqdm import tqdm
|
228 |
-
|
229 |
-
class ExplainedAnswer(BaseModel):
|
230 |
-
"""
|
231 |
-
The answer to the question and a terse explanation of the answer.
|
232 |
-
"""
|
233 |
-
|
234 |
-
answer: str = Field(description="The short answer to the question")
|
235 |
-
explanation: str = Field(description="5 words terse best explanation of the answer.")
|
236 |
-
|
237 |
-
models = list(AVAILABLE_MODELS.keys())[:1] # Just use the first model for testing
|
238 |
-
system = "You are an accurate and concise explainer of scientific concepts."
|
239 |
-
prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
|
240 |
-
|
241 |
-
llm_cache = LLMCache(cache_dir=".", hf_repo="qanta-challenge/advcal-llm-cache", reset=True)
|
242 |
-
|
243 |
-
# First call - should be a cache miss
|
244 |
-
logger.info("First call - should be a cache miss")
|
245 |
-
for model in tqdm(models):
|
246 |
-
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
|
247 |
-
rprint(response)
|
248 |
-
|
249 |
-
# Second call - should be a cache hit
|
250 |
-
logger.info("Second call - should be a cache hit")
|
251 |
-
for model in tqdm(models):
|
252 |
-
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
|
253 |
-
rprint(response)
|
254 |
-
|
255 |
-
# Slightly different prompt - should be a cache miss
|
256 |
-
logger.info("Different prompt - should be a cache miss")
|
257 |
-
prompt2 = "Which planet is closest to the sun? Answer directly."
|
258 |
-
for model in tqdm(models):
|
259 |
-
response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
|
260 |
-
rprint(response)
|
261 |
-
|
262 |
-
# Get cache entries count from SQLite
|
263 |
-
try:
|
264 |
-
cache_entries = llm_cache.get_all_entries()
|
265 |
-
logger.info(f"Cache now has {len(cache_entries)} items")
|
266 |
-
except Exception as e:
|
267 |
-
logger.error(f"Failed to get cache entries: {e}")
|
268 |
-
|
269 |
-
# Test adding entry with temperature parameter
|
270 |
-
logger.info("Testing with temperature parameter")
|
271 |
-
response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
|
272 |
-
rprint(response)
|
273 |
-
|
274 |
-
# Demonstrate forced sync to HF if repo is configured
|
275 |
-
if llm_cache.hf_repo_id:
|
276 |
-
logger.info("Forcing sync to HF dataset")
|
277 |
-
try:
|
278 |
-
llm_cache.sync_to_hf()
|
279 |
-
logger.info("Successfully synced to HF dataset")
|
280 |
-
except Exception as e:
|
281 |
-
logger.exception(f"Failed to sync to HF: {e}")
|
282 |
-
else:
|
283 |
-
logger.info("HF repo not configured, skipping sync test")
|
284 |
-
|
285 |
-
# %%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/qb_agents.py
DELETED
@@ -1,232 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
from typing import Any, Iterable, TypedDict
|
3 |
-
|
4 |
-
from loguru import logger
|
5 |
-
|
6 |
-
from .executors import WorkflowOutput, execute_workflow
|
7 |
-
from .structs import TossupWorkflow, Workflow
|
8 |
-
|
9 |
-
|
10 |
-
def _get_workflow_response(
|
11 |
-
workflow: Workflow, available_vars: dict[str, Any], logprob_step: bool | str = False
|
12 |
-
) -> tuple[WorkflowOutput, float]:
|
13 |
-
"""Get response from executing a complete workflow."""
|
14 |
-
start_time = time.time()
|
15 |
-
workflow_output = execute_workflow(workflow, available_vars, return_full_content=True, logprob_step=logprob_step)
|
16 |
-
response_time = time.time() - start_time
|
17 |
-
return workflow_output, response_time
|
18 |
-
|
19 |
-
|
20 |
-
class TossupResult(TypedDict):
|
21 |
-
answer: str # the model's answer
|
22 |
-
confidence: float # confidence score
|
23 |
-
logprob: float | None # log probability of the answer
|
24 |
-
buzz: bool # whether the agent buzzed
|
25 |
-
question_fragment: str # prefix of the question text so far
|
26 |
-
position: int # 1-indexed question run index
|
27 |
-
step_contents: list[str] # string content outputs of each step
|
28 |
-
response_time: float
|
29 |
-
step_outputs: dict[str, Any]
|
30 |
-
|
31 |
-
|
32 |
-
class BonusResult(TypedDict):
|
33 |
-
answer: str
|
34 |
-
confidence: float
|
35 |
-
explanation: str
|
36 |
-
response_time: float
|
37 |
-
step_contents: list[str]
|
38 |
-
step_outputs: dict[str, Any]
|
39 |
-
|
40 |
-
|
41 |
-
class QuizBowlTossupAgent:
|
42 |
-
"""Agent for handling tossup questions with multiple steps in the workflow."""
|
43 |
-
|
44 |
-
external_input_variable = "question_text"
|
45 |
-
output_variables = ["answer", "confidence"]
|
46 |
-
|
47 |
-
def __init__(self, workflow: TossupWorkflow):
|
48 |
-
"""Initialize the multi-step tossup agent.
|
49 |
-
|
50 |
-
Args:
|
51 |
-
workflow: The workflow containing multiple steps
|
52 |
-
buzz_threshold: Confidence threshold for buzzing
|
53 |
-
"""
|
54 |
-
self.workflow = workflow
|
55 |
-
self.output_variables = list(workflow.outputs.keys())
|
56 |
-
|
57 |
-
# Validate input variables
|
58 |
-
if self.external_input_variable not in workflow.inputs:
|
59 |
-
raise ValueError(f"External input variable {self.external_input_variable} not found in workflow inputs")
|
60 |
-
|
61 |
-
# Validate output variables
|
62 |
-
for out_var in self.output_variables:
|
63 |
-
if out_var not in workflow.outputs:
|
64 |
-
raise ValueError(f"Output variable {out_var} not found in workflow outputs")
|
65 |
-
|
66 |
-
def _single_run(self, question_run: str, position: int) -> TossupResult:
|
67 |
-
"""Process a single question run.
|
68 |
-
Args:
|
69 |
-
question_run: The question run to process
|
70 |
-
position: The position of the question run
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
A TossupResult containing the answer, confidence, logprob, buzz, question fragment, position, step contents, response time, and step outputs
|
74 |
-
"""
|
75 |
-
answer_var_step = self.workflow.outputs["answer"].split(".")[0]
|
76 |
-
workflow_output, response_time = _get_workflow_response(
|
77 |
-
self.workflow, {self.external_input_variable: question_run}, logprob_step=answer_var_step
|
78 |
-
)
|
79 |
-
final_outputs = workflow_output["final_outputs"]
|
80 |
-
buzz = self.workflow.buzzer.run(final_outputs["confidence"], logprob=workflow_output["logprob"])
|
81 |
-
result: TossupResult = {
|
82 |
-
"position": position,
|
83 |
-
"answer": final_outputs["answer"],
|
84 |
-
"confidence": final_outputs["confidence"],
|
85 |
-
"logprob": workflow_output["logprob"],
|
86 |
-
"buzz": buzz,
|
87 |
-
"question_fragment": question_run,
|
88 |
-
"step_contents": workflow_output["step_contents"],
|
89 |
-
"step_outputs": workflow_output["intermediate_outputs"], # Include intermediate step outputs
|
90 |
-
"response_time": response_time,
|
91 |
-
}
|
92 |
-
return result
|
93 |
-
|
94 |
-
def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[TossupResult]:
|
95 |
-
"""Process a tossup question and decide when to buzz based on confidence.
|
96 |
-
|
97 |
-
Args:
|
98 |
-
question_runs: Progressive reveals of the question text
|
99 |
-
early_stop: Whether to stop after the first buzz
|
100 |
-
|
101 |
-
Yields:
|
102 |
-
Dict containing:
|
103 |
-
- answer: The model's answer
|
104 |
-
- confidence: Confidence score
|
105 |
-
- buzz: Whether to buzz
|
106 |
-
- question_fragment: Current question text
|
107 |
-
- position: Current position in question
|
108 |
-
- step_contents: String content outputs of each step
|
109 |
-
- response_time: Time taken for response
|
110 |
-
- step_outputs: Outputs from each step
|
111 |
-
"""
|
112 |
-
for i, question_text in enumerate(question_runs):
|
113 |
-
# Execute the complete workflow
|
114 |
-
result = self._single_run(question_text, i + 1)
|
115 |
-
|
116 |
-
yield result
|
117 |
-
|
118 |
-
# If we've reached the confidence threshold, buzz and stop
|
119 |
-
if early_stop and result["buzz"]:
|
120 |
-
if i + 1 < len(question_runs):
|
121 |
-
yield self._single_run(question_runs[-1], len(question_runs))
|
122 |
-
return
|
123 |
-
|
124 |
-
|
125 |
-
class QuizBowlBonusAgent:
|
126 |
-
"""Agent for handling bonus questions with multiple steps in the workflow."""
|
127 |
-
|
128 |
-
external_input_variables = ["leadin", "part"]
|
129 |
-
output_variables = ["answer", "confidence", "explanation"]
|
130 |
-
|
131 |
-
def __init__(self, workflow: Workflow):
|
132 |
-
"""Initialize the multi-step bonus agent.
|
133 |
-
|
134 |
-
Args:
|
135 |
-
workflow: The workflow containing multiple steps
|
136 |
-
"""
|
137 |
-
self.workflow = workflow
|
138 |
-
self.output_variables = list(workflow.outputs.keys())
|
139 |
-
|
140 |
-
# Validate input variables
|
141 |
-
for input_var in self.external_input_variables:
|
142 |
-
if input_var not in workflow.inputs:
|
143 |
-
raise ValueError(f"External input variable {input_var} not found in workflow inputs")
|
144 |
-
|
145 |
-
# Validate output variables
|
146 |
-
for out_var in self.output_variables:
|
147 |
-
if out_var not in workflow.outputs:
|
148 |
-
raise ValueError(f"Output variable {out_var} not found in workflow outputs")
|
149 |
-
|
150 |
-
def run(self, leadin: str, part: str) -> BonusResult:
|
151 |
-
"""Process a bonus part with the given leadin.
|
152 |
-
|
153 |
-
Args:
|
154 |
-
leadin: The leadin text for the bonus question
|
155 |
-
part: The specific part text to answer
|
156 |
-
|
157 |
-
Returns:
|
158 |
-
Dict containing:
|
159 |
-
- answer: The model's answer
|
160 |
-
- confidence: Confidence score
|
161 |
-
- explanation: Explanation for the answer
|
162 |
-
- step_contents: String content outputs of each step
|
163 |
-
- response_time: Time taken for response
|
164 |
-
- step_outputs: Outputs from each step
|
165 |
-
"""
|
166 |
-
workflow_output, response_time = _get_workflow_response(
|
167 |
-
self.workflow,
|
168 |
-
{
|
169 |
-
"leadin": leadin,
|
170 |
-
"part": part,
|
171 |
-
},
|
172 |
-
)
|
173 |
-
final_outputs = workflow_output["final_outputs"]
|
174 |
-
return {
|
175 |
-
"answer": final_outputs["answer"],
|
176 |
-
"confidence": final_outputs["confidence"],
|
177 |
-
"explanation": final_outputs["explanation"],
|
178 |
-
"step_contents": workflow_output["step_contents"],
|
179 |
-
"response_time": response_time,
|
180 |
-
"step_outputs": workflow_output["intermediate_outputs"], # Include intermediate step outputs
|
181 |
-
}
|
182 |
-
|
183 |
-
|
184 |
-
# Example usage
|
185 |
-
if __name__ == "__main__":
|
186 |
-
# Load the Quizbowl dataset
|
187 |
-
from datasets import load_dataset
|
188 |
-
|
189 |
-
from workflows.factory import create_quizbowl_bonus_workflow, create_quizbowl_tossup_workflow
|
190 |
-
|
191 |
-
ds_name = "qanta-challenge/leaderboard_co_set"
|
192 |
-
ds = load_dataset(ds_name, split="train")
|
193 |
-
|
194 |
-
# Create the agents with multi-step workflows
|
195 |
-
tossup_workflow = create_quizbowl_tossup_workflow()
|
196 |
-
tossup_agent = QuizBowlTossupAgent(workflow=tossup_workflow, buzz_threshold=0.9)
|
197 |
-
|
198 |
-
bonus_workflow = create_quizbowl_bonus_workflow()
|
199 |
-
bonus_agent = QuizBowlBonusAgent(workflow=bonus_workflow)
|
200 |
-
|
201 |
-
# Example for tossup mode
|
202 |
-
print("\n=== TOSSUP MODE EXAMPLE ===")
|
203 |
-
sample_question = ds[30]
|
204 |
-
print(sample_question["question_runs"][-1])
|
205 |
-
print(sample_question["gold_label"])
|
206 |
-
print()
|
207 |
-
question_runs = sample_question["question_runs"]
|
208 |
-
|
209 |
-
results = tossup_agent.run(question_runs, early_stop=True)
|
210 |
-
for result in results:
|
211 |
-
print(result["step_contents"])
|
212 |
-
print(f"Guess at position {result['position']}: {result['answer']}")
|
213 |
-
print(f"Confidence: {result['confidence']}")
|
214 |
-
print("Step outputs:", result["step_outputs"])
|
215 |
-
if result["buzz"]:
|
216 |
-
print("Buzzed!\n")
|
217 |
-
|
218 |
-
# Example for bonus mode
|
219 |
-
print("\n=== BONUS MODE EXAMPLE ===")
|
220 |
-
sample_bonus = ds[31] # Assuming this is a bonus question
|
221 |
-
leadin = sample_bonus["leadin"]
|
222 |
-
parts = sample_bonus["parts"]
|
223 |
-
|
224 |
-
print(f"Leadin: {leadin}")
|
225 |
-
for i, part in enumerate(parts):
|
226 |
-
print(f"\nPart {i + 1}: {part['part']}")
|
227 |
-
result = bonus_agent.run(leadin, part["part"])
|
228 |
-
print(f"Answer: {result['answer']}")
|
229 |
-
print(f"Confidence: {result['confidence']}")
|
230 |
-
print(f"Explanation: {result['explanation']}")
|
231 |
-
print(f"Response time: {result['response_time']:.2f}s")
|
232 |
-
print("Step outputs:", result["step_outputs"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/structs.py
DELETED
@@ -1,370 +0,0 @@
|
|
1 |
-
# %%
|
2 |
-
from copy import deepcopy
|
3 |
-
from enum import Enum
|
4 |
-
from typing import Any, Literal, Optional
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
from pydantic import BaseModel, Field, model_validator
|
8 |
-
|
9 |
-
from .configs import AVAILABLE_MODELS
|
10 |
-
|
11 |
-
"""
|
12 |
-
Core data structures for defining workflows and their components.
|
13 |
-
|
14 |
-
This module defines the primary classes used to model workflows, steps, and their
|
15 |
-
input/output fields. These data structures serve as the foundation for workflow
|
16 |
-
definition, validation, and execution throughout the workflows package.
|
17 |
-
|
18 |
-
The primary components are:
|
19 |
-
- InputField: Represents an input to a model step with name and source variable
|
20 |
-
- OutputField: Represents an output from a model step with name and type
|
21 |
-
- ModelStep: Represents a single step in a workflow with inputs and outputs
|
22 |
-
- Workflow: A collection of interconnected steps with defined inputs and outputs
|
23 |
-
|
24 |
-
All classes use Pydantic's BaseModel for validation and serialization support.
|
25 |
-
"""
|
26 |
-
FieldType = Literal["input", "output"]
|
27 |
-
|
28 |
-
|
29 |
-
SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
|
30 |
-
"""Supported field types for input and output fields"""
|
31 |
-
|
32 |
-
|
33 |
-
class InputField(BaseModel):
|
34 |
-
"""
|
35 |
-
Defines an input field for a model step.
|
36 |
-
|
37 |
-
An input field specifies what data a step requires, where it comes from,
|
38 |
-
and optional pre-processing to apply before use.
|
39 |
-
|
40 |
-
Attributes:
|
41 |
-
name: The name of the input field within the step's context
|
42 |
-
description: Human-readable description of the input's purpose
|
43 |
-
variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
|
44 |
-
func: Optional function name to transform the input value before use
|
45 |
-
"""
|
46 |
-
|
47 |
-
name: str
|
48 |
-
description: str
|
49 |
-
variable: str
|
50 |
-
|
51 |
-
# function to call on the input before passing it to the model
|
52 |
-
func: str | None = None
|
53 |
-
|
54 |
-
class Config:
|
55 |
-
frozen = True
|
56 |
-
|
57 |
-
|
58 |
-
class OutputField(BaseModel):
|
59 |
-
"""
|
60 |
-
Defines an output field produced by a model step.
|
61 |
-
|
62 |
-
An output field specifies a value that the step will produce, including
|
63 |
-
its data type and optional post-processing.
|
64 |
-
|
65 |
-
Attributes:
|
66 |
-
name: The name of the output field within the step's context
|
67 |
-
description: Human-readable description of the output's purpose
|
68 |
-
type: The data type of the output (one of SUPPORTED_TYPES)
|
69 |
-
func: Optional function name to transform the raw output value
|
70 |
-
"""
|
71 |
-
|
72 |
-
name: str
|
73 |
-
type: SUPPORTED_TYPES = Field(default="str")
|
74 |
-
description: str
|
75 |
-
|
76 |
-
# function to call on the output string from the model
|
77 |
-
func: str | None = None
|
78 |
-
|
79 |
-
class Config:
|
80 |
-
frozen = True
|
81 |
-
|
82 |
-
|
83 |
-
class CallType(str, Enum):
|
84 |
-
LLM = "llm"
|
85 |
-
SEARCH = "search"
|
86 |
-
PYTHON_FUNC = "python_func"
|
87 |
-
|
88 |
-
|
89 |
-
class ModelStep(BaseModel):
|
90 |
-
"""
|
91 |
-
Represents a single step in a workflow.
|
92 |
-
|
93 |
-
A model step encapsulates the details of a specific operation within a workflow,
|
94 |
-
including what model to use, what inputs it requires, and what outputs it produces.
|
95 |
-
|
96 |
-
Attributes:
|
97 |
-
id: Unique identifier for this step within a workflow
|
98 |
-
model: The model to use for this step (e.g., "gpt-4")
|
99 |
-
provider: The provider of the model (e.g., "openai")
|
100 |
-
call_type: The type of operation (e.g., "llm", "search")
|
101 |
-
system_prompt: Instructions for the model
|
102 |
-
input_fields: List of input fields required by this step
|
103 |
-
output_fields: List of output fields produced by this step
|
104 |
-
"""
|
105 |
-
|
106 |
-
id: str
|
107 |
-
name: str
|
108 |
-
model: str
|
109 |
-
provider: str
|
110 |
-
call_type: CallType = CallType.LLM
|
111 |
-
|
112 |
-
# TODO: Validate that this is not None for call_type = llm
|
113 |
-
temperature: Optional[float] = None
|
114 |
-
|
115 |
-
system_prompt: str
|
116 |
-
input_fields: list[InputField]
|
117 |
-
output_fields: list[OutputField]
|
118 |
-
|
119 |
-
class Config:
|
120 |
-
use_enum_values = True
|
121 |
-
|
122 |
-
def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
|
123 |
-
return self.input_fields if field_type == "input" else self.output_fields
|
124 |
-
|
125 |
-
def get_full_model_name(self) -> str:
|
126 |
-
return f"{self.provider}/{self.model}"
|
127 |
-
|
128 |
-
def get_produced_variables(self) -> list[str]:
|
129 |
-
return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]
|
130 |
-
|
131 |
-
def update(self, update: dict[str, Any]) -> "ModelStep":
|
132 |
-
"""Returns a new copy with the updated properties."""
|
133 |
-
return self.model_copy(update=update)
|
134 |
-
|
135 |
-
def update_property(self, field: str, value: Any) -> "ModelStep":
|
136 |
-
"Update the `field` key of the model step with `value`."
|
137 |
-
return self.update({field: value})
|
138 |
-
|
139 |
-
def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
|
140 |
-
"""Update a specific field of an input or output field at the given index."""
|
141 |
-
if field_type == "input":
|
142 |
-
fields = self.input_fields
|
143 |
-
elif field_type == "output":
|
144 |
-
fields = self.output_fields
|
145 |
-
else:
|
146 |
-
raise ValueError(f"Invalid field type: {field_type}")
|
147 |
-
|
148 |
-
if index < len(fields):
|
149 |
-
fields[index] = fields[index].model_copy(update={key: value})
|
150 |
-
return self.model_copy()
|
151 |
-
|
152 |
-
@staticmethod
|
153 |
-
def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
|
154 |
-
if field_type == "input":
|
155 |
-
return InputField(name="", description="", variable=input_var)
|
156 |
-
elif field_type == "output":
|
157 |
-
return OutputField(name="", description="")
|
158 |
-
else:
|
159 |
-
raise ValueError(f"Invalid field type: {field_type}")
|
160 |
-
|
161 |
-
def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
|
162 |
-
"""Add a new field to the state and update visibility.
|
163 |
-
|
164 |
-
Args:
|
165 |
-
field_type: Type of field to add ('input' or 'output').
|
166 |
-
index: Position to insert the new field (-1 to append).
|
167 |
-
Returns:
|
168 |
-
A new ModelStep with the updated fields.
|
169 |
-
"""
|
170 |
-
if field_type == "input":
|
171 |
-
fields = deepcopy(self.input_fields)
|
172 |
-
new_field = ModelStep.create_new_field(field_type, input_var)
|
173 |
-
fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
|
174 |
-
return self.model_copy(update={"input_fields": fields})
|
175 |
-
else:
|
176 |
-
fields = deepcopy(self.output_fields)
|
177 |
-
new_field = ModelStep.create_new_field(field_type)
|
178 |
-
fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
|
179 |
-
return self.model_copy(update={"output_fields": fields})
|
180 |
-
|
181 |
-
def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
|
182 |
-
"""
|
183 |
-
Delete an input or output field from the state and update visibility.
|
184 |
-
|
185 |
-
Args:
|
186 |
-
field_type: Type of field to delete ('input' or 'output').
|
187 |
-
index: Index of the field to delete. [-1 to delete the last field]
|
188 |
-
|
189 |
-
Returns:
|
190 |
-
A new ModelStep with the updated fields.
|
191 |
-
"""
|
192 |
-
fields = self.input_fields if field_type == "input" else self.output_fields
|
193 |
-
fields = deepcopy(fields)
|
194 |
-
fields.pop(index)
|
195 |
-
return self.model_copy(update={"input_fields": fields} if field_type == "input" else {"output_fields": fields})
|
196 |
-
|
197 |
-
|
198 |
-
class Workflow(BaseModel):
|
199 |
-
"""
|
200 |
-
Represents a complete workflow composed of interconnected steps.
|
201 |
-
|
202 |
-
A workflow defines a directed acyclic graph of model steps, where outputs
|
203 |
-
from earlier steps can be used as inputs to later steps.
|
204 |
-
|
205 |
-
Attributes:
|
206 |
-
inputs: List of input variables required by the workflow
|
207 |
-
outputs: List of output variables produced by the workflow
|
208 |
-
steps: Dictionary mapping step IDs to ModelStep instances
|
209 |
-
|
210 |
-
The inputs and outputs lists use the format "{step_id}.{field_name}"
|
211 |
-
to uniquely identify variables within the workflow.
|
212 |
-
"""
|
213 |
-
|
214 |
-
# variables of form {node}.{field}
|
215 |
-
inputs: list[str] = Field(default_factory=list)
|
216 |
-
|
217 |
-
# variables of form {node}.{field}
|
218 |
-
outputs: dict[str, str | None] = Field(default_factory=dict)
|
219 |
-
steps: dict[str, ModelStep] = Field(default_factory=dict)
|
220 |
-
|
221 |
-
def model_dump(self, *args, **kwargs):
|
222 |
-
data = super().model_dump(*args, **kwargs)
|
223 |
-
if "steps" in data:
|
224 |
-
data["steps"] = list(data["steps"].values())
|
225 |
-
return data
|
226 |
-
|
227 |
-
@model_validator(mode="before")
|
228 |
-
def dictify_steps(cls, data):
|
229 |
-
if "steps" in data and isinstance(data["steps"], list):
|
230 |
-
steps_dict = {}
|
231 |
-
for step in data["steps"]:
|
232 |
-
if isinstance(step, ModelStep):
|
233 |
-
step_id = step.id
|
234 |
-
else:
|
235 |
-
step_id = step["id"]
|
236 |
-
if step_id in steps_dict:
|
237 |
-
raise ValueError(f"Duplicate step ID: {step_id}")
|
238 |
-
steps_dict[step_id] = step
|
239 |
-
data["steps"] = steps_dict
|
240 |
-
return data
|
241 |
-
|
242 |
-
def get_step_variables(self, step_id: str) -> list[str]:
|
243 |
-
"""Get all variables from a specific step."""
|
244 |
-
step = self.steps[step_id]
|
245 |
-
variables = []
|
246 |
-
for output in step.output_fields:
|
247 |
-
if output.name == "":
|
248 |
-
continue
|
249 |
-
output_var = f"{step.id}.{output.name}"
|
250 |
-
variables.append(output_var)
|
251 |
-
return variables
|
252 |
-
|
253 |
-
def get_available_variables(self) -> list[str]:
|
254 |
-
"""Get all output variables from all steps."""
|
255 |
-
variables = set(self.inputs)
|
256 |
-
for step in self.steps.values():
|
257 |
-
variables.update(self.get_step_variables(step.id))
|
258 |
-
return list(variables)
|
259 |
-
|
260 |
-
def get_step_model_selections(self) -> dict[str, str]:
|
261 |
-
"""Get all model selections for all steps."""
|
262 |
-
return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}
|
263 |
-
|
264 |
-
def get_output_model_selections(self) -> dict[str, str]:
|
265 |
-
"""Get all output model selections for all steps."""
|
266 |
-
return {
|
267 |
-
output_var: target_var.split(".")[0] if target_var else None
|
268 |
-
for output_var, target_var in self.outputs.items()
|
269 |
-
}
|
270 |
-
|
271 |
-
# Step update method
|
272 |
-
|
273 |
-
def add_step(self, step: ModelStep) -> "Workflow":
|
274 |
-
"""Add a step to the workflow."""
|
275 |
-
steps = self.steps | {step.id: step}
|
276 |
-
return self.model_copy(update={"steps": steps})
|
277 |
-
|
278 |
-
def remove_step(self, step_id: str) -> "Workflow":
|
279 |
-
"""Remove a step from the workflow."""
|
280 |
-
self.steps.pop(step_id)
|
281 |
-
workflow = self.model_copy(update={"steps": self.steps})
|
282 |
-
workflow.refresh_output_variables()
|
283 |
-
return workflow
|
284 |
-
|
285 |
-
def update_step(self, step: ModelStep) -> "Workflow":
|
286 |
-
"""Update a step in the workflow."""
|
287 |
-
self.steps[step.id] = step
|
288 |
-
steps = self.steps | {step.id: step}
|
289 |
-
workflow = self.model_copy(update={"steps": steps})
|
290 |
-
workflow.refresh_output_variables()
|
291 |
-
return workflow
|
292 |
-
|
293 |
-
# Output variables
|
294 |
-
def refresh_output_variables(self) -> "Workflow":
|
295 |
-
"""Refresh the output variables for the workflow."""
|
296 |
-
produced_variables = self.get_available_variables()
|
297 |
-
self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
|
298 |
-
return self
|
299 |
-
|
300 |
-
|
301 |
-
class BuzzerMethod(str, Enum):
|
302 |
-
AND = "AND"
|
303 |
-
OR = "OR"
|
304 |
-
|
305 |
-
|
306 |
-
class Buzzer(BaseModel):
|
307 |
-
"""Configuration for when to buzz in a tossup question."""
|
308 |
-
|
309 |
-
method: BuzzerMethod = BuzzerMethod.AND # Logic to combine thresholds
|
310 |
-
confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0) # Minimum confidence to trigger a buzz
|
311 |
-
prob_threshold: float | None = None # Optional log probability threshold
|
312 |
-
|
313 |
-
class Config:
|
314 |
-
use_enum_values = True
|
315 |
-
frozen = True
|
316 |
-
|
317 |
-
def update(self, **kwargs) -> "Buzzer":
|
318 |
-
"""Update the buzzer with the given kwargs."""
|
319 |
-
return self.model_copy(update=kwargs)
|
320 |
-
|
321 |
-
def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
|
322 |
-
"""Run the buzzer logic."""
|
323 |
-
if logprob is not None and prob is not None:
|
324 |
-
raise ValueError("Cannot provide both logprob and prob")
|
325 |
-
if self.prob_threshold is None:
|
326 |
-
return confidence >= self.confidence_threshold
|
327 |
-
if logprob is None and prob is None:
|
328 |
-
raise ValueError("Must provide either logprob or prob if prob_threshold is not None")
|
329 |
-
prob = prob or float(np.exp(logprob))
|
330 |
-
if self.method == BuzzerMethod.AND:
|
331 |
-
return confidence >= self.confidence_threshold and prob >= self.prob_threshold
|
332 |
-
elif self.method == BuzzerMethod.OR:
|
333 |
-
return confidence >= self.confidence_threshold or prob >= self.prob_threshold
|
334 |
-
else:
|
335 |
-
raise ValueError(f"Invalid buzzer method: {self.method}")
|
336 |
-
|
337 |
-
@model_validator(mode="after")
|
338 |
-
def validate_method_with_log_prob(cls, data):
|
339 |
-
"""Validate that if prob_threshold is None, method must be 'and'."""
|
340 |
-
if data.prob_threshold is None and data.method != BuzzerMethod.AND:
|
341 |
-
raise ValueError("If prob_threshold is None, method must be 'and'")
|
342 |
-
return data
|
343 |
-
|
344 |
-
|
345 |
-
class TossupWorkflow(Workflow):
|
346 |
-
"""Workflow specialized for tossup questions with buzzing capability."""
|
347 |
-
|
348 |
-
buzzer: Buzzer = Field(default_factory=Buzzer)
|
349 |
-
|
350 |
-
def get_answer_model(self, answer_var: str | None = None) -> str | None:
|
351 |
-
answer_var = answer_var or self.outputs["answer"]
|
352 |
-
if answer_var is None:
|
353 |
-
return None
|
354 |
-
step_id = answer_var.split(".")[0]
|
355 |
-
return self.steps[step_id].get_full_model_name()
|
356 |
-
|
357 |
-
def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
|
358 |
-
model_name = self.get_answer_model(answer_var)
|
359 |
-
if model_name is None:
|
360 |
-
return True
|
361 |
-
return AVAILABLE_MODELS[model_name].get("logprobs", False)
|
362 |
-
|
363 |
-
def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
|
364 |
-
"""Update the buzzer."""
|
365 |
-
return self.model_copy(update={"buzzer": buzzer})
|
366 |
-
|
367 |
-
def refresh_buzzer(self) -> "TossupWorkflow":
|
368 |
-
if not self.is_token_probs_supported():
|
369 |
-
return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
|
370 |
-
return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/utils.py
DELETED
@@ -1,195 +0,0 @@
|
|
1 |
-
from collections import deque
|
2 |
-
from typing import Any, Iterable
|
3 |
-
|
4 |
-
from .errors import CyclicDependencyError, UnknownVariableError, WorkflowError
|
5 |
-
from .structs import Workflow
|
6 |
-
|
7 |
-
"""
|
8 |
-
Utilities for workflow dependency management and execution order determination.
|
9 |
-
|
10 |
-
This module provides functions for analyzing workflows, determining dependencies between steps,
|
11 |
-
and calculating the correct execution order to ensure all dependencies are satisfied.
|
12 |
-
Key functionality includes:
|
13 |
-
|
14 |
-
- Variable to step mapping: Identifying which step produces each variable
|
15 |
-
- Dependency graph creation: Building a graph representing dependencies between steps
|
16 |
-
- Topological sorting: Determining a valid execution order based on dependencies
|
17 |
-
- Cycle detection: Identifying cyclic dependencies that would prevent execution
|
18 |
-
|
19 |
-
These utilities form the foundation for workflow validation and execution in the
|
20 |
-
workflow_executor module.
|
21 |
-
"""
|
22 |
-
|
23 |
-
|
24 |
-
def _create_variable_step_mapping(workflow: Workflow) -> dict[str, str]:
|
25 |
-
"""
|
26 |
-
Creates a mapping from produced variable names to the model step that produces them.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
workflow (Workflow): The workflow containing steps and their input/output fields.
|
30 |
-
|
31 |
-
Returns:
|
32 |
-
dict[str, str]: A dictionary where keys are variable names (formatted as "{step_id}.{output name}")
|
33 |
-
and values are the step IDs that produce them.
|
34 |
-
|
35 |
-
Raises:
|
36 |
-
WorkflowError: If there are duplicate step IDs or if a variable is produced by multiple steps.
|
37 |
-
|
38 |
-
Example:
|
39 |
-
For a workflow with steps "extract" and "summarize" each producing outputs:
|
40 |
-
>>> _create_variable_step_mapping(workflow)
|
41 |
-
{'extract.keywords': 'extract', 'summarize.summary': 'summarize'}
|
42 |
-
"""
|
43 |
-
variable_step_map: dict[str, str] = {} # variable name -> step id
|
44 |
-
for step_id, step in workflow.steps.items():
|
45 |
-
for output in step.output_fields:
|
46 |
-
var_name = f"{step_id}.{output.name}"
|
47 |
-
if var_name in variable_step_map:
|
48 |
-
raise WorkflowError(f"Variable '{output.name}' has duplicate entry in step {step_id}")
|
49 |
-
variable_step_map[var_name] = step_id
|
50 |
-
return variable_step_map
|
51 |
-
|
52 |
-
|
53 |
-
def create_dependency_graph(workflow: Workflow, input_values: dict[str, Any]) -> dict[str, set[str]]:
|
54 |
-
"""
|
55 |
-
Creates a dependency graph from a workflow.
|
56 |
-
|
57 |
-
This function analyzes the workflow and determines which steps depend on others
|
58 |
-
based on their input/output relationships. A step depends on another if it requires
|
59 |
-
a variable that is produced by the other step. External inputs provided through
|
60 |
-
input_values don't create dependencies.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
workflow (Workflow): The workflow containing steps and their input/output fields.
|
64 |
-
input_values (dict[str, Any]): A dictionary of external input values provided to the workflow.
|
65 |
-
|
66 |
-
Returns:
|
67 |
-
dict[str, set[str]]: A dictionary where keys are step IDs and values are sets of step IDs
|
68 |
-
that the key step depends on.
|
69 |
-
|
70 |
-
Raises:
|
71 |
-
UnknownVariableError: If an input field references a variable that is not provided
|
72 |
-
externally nor produced by any step.
|
73 |
-
|
74 |
-
Example:
|
75 |
-
For a workflow where step "classify" depends on output from "extract":
|
76 |
-
>>> create_dependency_graph(workflow, {})
|
77 |
-
{'extract': set(), 'classify': {'extract'}}
|
78 |
-
|
79 |
-
With external input provided for "text" variable:
|
80 |
-
>>> create_dependency_graph(workflow, {'text': 'Sample text'})
|
81 |
-
{'extract': set(), 'classify': {'extract'}}
|
82 |
-
"""
|
83 |
-
produced_by = _create_variable_step_mapping(workflow)
|
84 |
-
dependencies: dict[str, set[str]] = {step_id: set() for step_id in workflow.steps}
|
85 |
-
for step_id, step in workflow.steps.items():
|
86 |
-
for input_field in step.input_fields:
|
87 |
-
var = input_field.variable
|
88 |
-
# If the variable was provided externally, then no dependency is needed.
|
89 |
-
if var in input_values:
|
90 |
-
continue
|
91 |
-
# Otherwise, check if the variable is produced by a step.
|
92 |
-
if var in produced_by:
|
93 |
-
producer_step_id = produced_by[var]
|
94 |
-
if producer_step_id != step_id: # Avoid self-dependency
|
95 |
-
dependencies[step_id].add(producer_step_id)
|
96 |
-
else:
|
97 |
-
raise UnknownVariableError(f"Variable '{var}' is not provided externally nor produced by any step")
|
98 |
-
return dependencies
|
99 |
-
|
100 |
-
|
101 |
-
def detect_cycles(dep_graph: dict[str, Iterable[str]]) -> str | None:
|
102 |
-
"""Detects cycles in the dependency graph.
|
103 |
-
Args:
|
104 |
-
dep_graph: A dictionary where the keys are node IDs and the values are the dependent node IDs
|
105 |
-
Returns:
|
106 |
-
The first step id of a model_step that is part of a cycle, None if no cycles are found
|
107 |
-
"""
|
108 |
-
# Check for cycles in step dependencies
|
109 |
-
visited = set()
|
110 |
-
path = set()
|
111 |
-
|
112 |
-
def has_cycle(node: str) -> bool:
|
113 |
-
if node in path:
|
114 |
-
return True
|
115 |
-
if node in visited:
|
116 |
-
return False
|
117 |
-
|
118 |
-
visited.add(node)
|
119 |
-
path.add(node)
|
120 |
-
|
121 |
-
for neighbor in dep_graph.get(node, set()):
|
122 |
-
if has_cycle(neighbor):
|
123 |
-
return True
|
124 |
-
|
125 |
-
path.remove(node)
|
126 |
-
return False
|
127 |
-
|
128 |
-
# Check each step for cycles
|
129 |
-
for node_id in dep_graph:
|
130 |
-
if has_cycle(node_id):
|
131 |
-
return node_id
|
132 |
-
return None
|
133 |
-
|
134 |
-
|
135 |
-
def topological_sort(dependencies: dict[str, set[str]]) -> list[str]:
|
136 |
-
"""
|
137 |
-
Performs a topological sort on a dependency graph and detects cycles using Kahn's algorithm.
|
138 |
-
|
139 |
-
A topological sort orders the steps such that for every dependency from step A to step B,
|
140 |
-
step A comes before step B in the ordering. This ensures that all dependencies are satisfied
|
141 |
-
when executing steps in the returned order.
|
142 |
-
|
143 |
-
Args:
|
144 |
-
dependencies (dict[str, set[str]]): A dictionary where each key is a node identifier and
|
145 |
-
each value is a set of nodes that the key node depends on.
|
146 |
-
|
147 |
-
Returns:
|
148 |
-
list[str]: A list representing the nodes in topological order if no cycle is detected.
|
149 |
-
|
150 |
-
Raises:
|
151 |
-
CyclicDependencyError: If a cycle is detected in the graph.
|
152 |
-
|
153 |
-
Example:
|
154 |
-
>>> topological_sort({'A': set(), 'B': {'A'}, 'C': {'B'}})
|
155 |
-
['A', 'B', 'C']
|
156 |
-
|
157 |
-
>>> topological_sort({'A': {'B'}, 'B': {'A'}}) # Cyclic dependency
|
158 |
-
CyclicDependencyError
|
159 |
-
|
160 |
-
Algorithm:
|
161 |
-
This implementation uses Kahn's algorithm:
|
162 |
-
1. Calculate in-degree for all nodes (number of dependencies)
|
163 |
-
2. Start with nodes having 0 in-degree (no dependencies)
|
164 |
-
3. Process each node by removing its outgoing edges
|
165 |
-
4. Add newly dependency-free nodes to the processing queue
|
166 |
-
5. If not all nodes are processed, a cycle exists
|
167 |
-
"""
|
168 |
-
|
169 |
-
nodes = list(dependencies.keys())
|
170 |
-
dependents: dict[str, list[str]] = {node: [] for node in nodes}
|
171 |
-
in_degree: dict[str, int] = dict.fromkeys(nodes, 0)
|
172 |
-
|
173 |
-
# Calculate in-degrees and build dependents list
|
174 |
-
for node, deps in dependencies.items():
|
175 |
-
in_degree[node] = len(deps)
|
176 |
-
for dep in deps:
|
177 |
-
dependents[dep].append(node)
|
178 |
-
|
179 |
-
# Initialize queue with nodes having zero in-degree
|
180 |
-
queue = deque([node for node, deg in in_degree.items() if deg == 0])
|
181 |
-
execution_order: list[str] = []
|
182 |
-
|
183 |
-
# Process nodes in topological order
|
184 |
-
while queue:
|
185 |
-
current = queue.popleft()
|
186 |
-
execution_order.append(current)
|
187 |
-
for dep in dependents[current]:
|
188 |
-
in_degree[dep] -= 1
|
189 |
-
if in_degree[dep] == 0:
|
190 |
-
queue.append(dep)
|
191 |
-
|
192 |
-
# If execution order includes all nodes, no cycle exists
|
193 |
-
if len(execution_order) != len(nodes):
|
194 |
-
raise CyclicDependencyError()
|
195 |
-
return execution_order
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/workflows/validators.py
DELETED
@@ -1,615 +0,0 @@
|
|
1 |
-
import keyword
|
2 |
-
import re
|
3 |
-
from dataclasses import dataclass
|
4 |
-
from enum import Enum
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
from .structs import CallType, InputField, ModelStep, OutputField, Workflow
|
8 |
-
from .utils import detect_cycles
|
9 |
-
|
10 |
-
SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"}
|
11 |
-
|
12 |
-
# Constants for validation
|
13 |
-
MAX_FIELD_NAME_LENGTH = 50
|
14 |
-
MAX_DESCRIPTION_LENGTH = 200
|
15 |
-
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
16 |
-
MAX_TEMPERATURE = 10.0
|
17 |
-
|
18 |
-
from loguru import logger
|
19 |
-
|
20 |
-
|
21 |
-
class ValidationErrorType(Enum):
|
22 |
-
"""Types of validation errors that can occur"""
|
23 |
-
|
24 |
-
INPUTS = "inputs"
|
25 |
-
OUTPUTS = "outputs"
|
26 |
-
STEP = "step"
|
27 |
-
DAG = "dag"
|
28 |
-
VARIABLE = "variable"
|
29 |
-
TYPE = "type"
|
30 |
-
GENERAL = "general"
|
31 |
-
NAMING = "naming"
|
32 |
-
LENGTH = "length"
|
33 |
-
RANGE = "range"
|
34 |
-
|
35 |
-
|
36 |
-
@dataclass
|
37 |
-
class ValidationError:
|
38 |
-
"""Represents a validation error with type and message"""
|
39 |
-
|
40 |
-
error_type: ValidationErrorType
|
41 |
-
message: str
|
42 |
-
step_id: Optional[str] = None
|
43 |
-
field_name: Optional[str] = None
|
44 |
-
|
45 |
-
def __str__(self):
|
46 |
-
subject = ""
|
47 |
-
if self.step_id:
|
48 |
-
subject = f"Model step '{self.step_id}'"
|
49 |
-
if self.field_name:
|
50 |
-
if self.step_id:
|
51 |
-
subject = f"Field '{self.step_id}.{self.field_name}'"
|
52 |
-
else:
|
53 |
-
subject = f"Field '{self.field_name}'"
|
54 |
-
return f"{self.error_type.value}: {subject} - {self.message}"
|
55 |
-
|
56 |
-
|
57 |
-
class WorkflowValidationError(ValueError):
|
58 |
-
"""Base class for workflow validation errors"""
|
59 |
-
|
60 |
-
def __init__(self, errors: list[ValidationError]):
|
61 |
-
self.errors = errors
|
62 |
-
super().__init__(f"Workflow validation failed with {len(errors)} errors")
|
63 |
-
|
64 |
-
|
65 |
-
def _parse_variable_reference(var: str) -> tuple[Optional[str], str]:
|
66 |
-
"""Extracts step_id and field_name from variable reference"""
|
67 |
-
parts = var.split(".")
|
68 |
-
if len(parts) == 1:
|
69 |
-
return None, parts[0]
|
70 |
-
return parts[0], parts[1]
|
71 |
-
|
72 |
-
|
73 |
-
def _get_step_dependencies(step: ModelStep) -> set[str]:
|
74 |
-
"""Gets set of step IDs that this step depends on"""
|
75 |
-
deps = set()
|
76 |
-
for field in step.input_fields:
|
77 |
-
step_id, _ = _parse_variable_reference(field.variable)
|
78 |
-
if step_id:
|
79 |
-
deps.add(step_id)
|
80 |
-
return deps
|
81 |
-
|
82 |
-
|
83 |
-
def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
|
84 |
-
"""Creates a dependency graph of steps"""
|
85 |
-
dep_graph: dict[str, set[str]] = {}
|
86 |
-
for step_id, step in workflow.steps.items():
|
87 |
-
dep_graph[step_id] = _get_step_dependencies(step)
|
88 |
-
return dep_graph
|
89 |
-
|
90 |
-
|
91 |
-
class WorkflowValidator:
|
92 |
-
"""Validates workflows for correctness and consistency"""
|
93 |
-
|
94 |
-
def __init__(
|
95 |
-
self,
|
96 |
-
min_temperature: float = 0,
|
97 |
-
max_temperature: float = MAX_TEMPERATURE,
|
98 |
-
max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
|
99 |
-
max_description_length: int = MAX_DESCRIPTION_LENGTH,
|
100 |
-
max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
|
101 |
-
allowed_model_names: Optional[list[str]] = None,
|
102 |
-
required_input_vars: Optional[list[str]] = None,
|
103 |
-
required_output_vars: Optional[list[str]] = None,
|
104 |
-
):
|
105 |
-
self.errors: list[ValidationError] = []
|
106 |
-
self.workflow: Optional[Workflow] = None
|
107 |
-
self.min_temperature = min_temperature
|
108 |
-
self.max_temperature = max_temperature
|
109 |
-
self.max_field_name_length = max_field_name_length
|
110 |
-
self.max_description_length = max_description_length
|
111 |
-
self.max_system_prompt_length = max_system_prompt_length
|
112 |
-
self.required_input_vars = required_input_vars
|
113 |
-
self.required_output_vars = required_output_vars
|
114 |
-
self.allowed_model_names = set(allowed_model_names) if allowed_model_names else None
|
115 |
-
|
116 |
-
def validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
117 |
-
validated = self._validate(workflow, allow_empty)
|
118 |
-
if not validated:
|
119 |
-
raise WorkflowValidationError(self.errors)
|
120 |
-
return True
|
121 |
-
|
122 |
-
def _validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
123 |
-
"""Main validation entry point
|
124 |
-
Args:
|
125 |
-
workflow: The workflow to validate.
|
126 |
-
allow_empty: If True, empty workflow is allowed. This flag is used to validate the intermediate states while User edits the workflow.
|
127 |
-
"""
|
128 |
-
self.errors = []
|
129 |
-
self.workflow = workflow
|
130 |
-
|
131 |
-
# Basic workflow validation
|
132 |
-
if not self._validate_workflow_basic(workflow, allow_empty):
|
133 |
-
return False
|
134 |
-
|
135 |
-
# If it's a single-step workflow, use simple validation
|
136 |
-
if len(workflow.steps) == 1:
|
137 |
-
return self.validate_simple_workflow(workflow, allow_empty)
|
138 |
-
|
139 |
-
# Otherwise use complex validation
|
140 |
-
return self.validate_complex_workflow(workflow, allow_empty)
|
141 |
-
|
142 |
-
def _validate_required_inputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
143 |
-
"""Validates that the workflow has the correct inputs"""
|
144 |
-
|
145 |
-
required_input_vars = self.required_input_vars or []
|
146 |
-
input_vars = set(workflow.inputs)
|
147 |
-
for req_var in required_input_vars:
|
148 |
-
if req_var in input_vars:
|
149 |
-
continue
|
150 |
-
self.errors.append(
|
151 |
-
ValidationError(ValidationErrorType.INPUTS, f"Workflow must have '{req_var}' as an input")
|
152 |
-
)
|
153 |
-
return False
|
154 |
-
|
155 |
-
for input_var in input_vars:
|
156 |
-
if not self._is_valid_external_input(input_var):
|
157 |
-
self.errors.append(
|
158 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
159 |
-
)
|
160 |
-
return False
|
161 |
-
return True
|
162 |
-
|
163 |
-
def _validate_required_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
164 |
-
"""Validates that the workflow has the correct outputs"""
|
165 |
-
|
166 |
-
required_output_vars = self.required_output_vars or []
|
167 |
-
output_vars = set(workflow.outputs)
|
168 |
-
for req_var in required_output_vars:
|
169 |
-
if req_var in output_vars:
|
170 |
-
continue
|
171 |
-
self.errors.append(
|
172 |
-
ValidationError(ValidationErrorType.OUTPUTS, f"Workflow must produce '{req_var}' as an output")
|
173 |
-
)
|
174 |
-
return False
|
175 |
-
|
176 |
-
# Validate output variables
|
177 |
-
for output_name, output_var in workflow.outputs.items():
|
178 |
-
logger.debug(f"Output name: {output_name}, Output var: {output_var}")
|
179 |
-
if not output_var:
|
180 |
-
if allow_empty:
|
181 |
-
continue
|
182 |
-
self.errors.append(
|
183 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
184 |
-
)
|
185 |
-
return False
|
186 |
-
|
187 |
-
# Check if output variable references a valid step output
|
188 |
-
if not self._is_valid_variable_reference(output_var):
|
189 |
-
self.errors.append(
|
190 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
191 |
-
)
|
192 |
-
return False
|
193 |
-
|
194 |
-
# Verify the output field exists in the referenced step
|
195 |
-
step_id, field_name = _parse_variable_reference(output_var)
|
196 |
-
logger.debug(f"Step ID: {step_id}, Field name: {field_name}, Workflow steps: {workflow.steps.keys()}")
|
197 |
-
if step_id not in workflow.steps:
|
198 |
-
self.errors.append(
|
199 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Referenced model step '{step_id}' not found")
|
200 |
-
)
|
201 |
-
return False
|
202 |
-
|
203 |
-
ref_step = workflow.steps[step_id]
|
204 |
-
if not any(field.name == field_name for field in ref_step.output_fields):
|
205 |
-
self.errors.append(
|
206 |
-
ValidationError(
|
207 |
-
ValidationErrorType.VARIABLE,
|
208 |
-
f"Output field '{field_name}' not found in model step '{step_id}'",
|
209 |
-
step_id,
|
210 |
-
field_name,
|
211 |
-
)
|
212 |
-
)
|
213 |
-
return False
|
214 |
-
return True
|
215 |
-
|
216 |
-
def validate_input_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
217 |
-
"""Validates the input and output variables"""
|
218 |
-
|
219 |
-
self._validate_required_inputs(workflow, allow_empty)
|
220 |
-
self._validate_required_outputs(workflow, allow_empty)
|
221 |
-
|
222 |
-
# Check for atleast one input
|
223 |
-
if not workflow.inputs:
|
224 |
-
self.errors.append(
|
225 |
-
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
|
226 |
-
)
|
227 |
-
|
228 |
-
# Check for atleast one output
|
229 |
-
if not workflow.outputs:
|
230 |
-
self.errors.append(
|
231 |
-
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
|
232 |
-
)
|
233 |
-
|
234 |
-
return len(self.errors) == 0
|
235 |
-
|
236 |
-
def validate_simple_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
237 |
-
"""Validates a single-step workflow"""
|
238 |
-
if not self.workflow:
|
239 |
-
return False
|
240 |
-
|
241 |
-
# Get the single step
|
242 |
-
step = next(iter(workflow.steps.values()))
|
243 |
-
|
244 |
-
# Validate the step itself
|
245 |
-
if not self._validate_step(step, allow_empty):
|
246 |
-
return False
|
247 |
-
|
248 |
-
return True
|
249 |
-
|
250 |
-
def validate_complex_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
251 |
-
"""Validates a multi-step workflow"""
|
252 |
-
if not self.workflow:
|
253 |
-
return False
|
254 |
-
|
255 |
-
# Validate each step
|
256 |
-
for step in workflow.steps.values():
|
257 |
-
if not self._validate_step(step, allow_empty):
|
258 |
-
return False
|
259 |
-
|
260 |
-
dep_graph = create_step_dep_graph(workflow)
|
261 |
-
if cycle_step_id := detect_cycles(dep_graph):
|
262 |
-
self.errors.append(
|
263 |
-
ValidationError(
|
264 |
-
ValidationErrorType.DAG, f"Circular dependency detected involving step: {cycle_step_id}"
|
265 |
-
)
|
266 |
-
)
|
267 |
-
return False
|
268 |
-
|
269 |
-
# Check for orphaned steps (steps that aren't used by any other step)
|
270 |
-
used_steps = set()
|
271 |
-
for deps in dep_graph.values():
|
272 |
-
used_steps.update(deps)
|
273 |
-
for step_id in workflow.steps:
|
274 |
-
if step_id not in used_steps and not any(
|
275 |
-
output_var and _parse_variable_reference(output_var)[0] == step_id
|
276 |
-
for output_var in workflow.outputs.values()
|
277 |
-
):
|
278 |
-
self.errors.append(ValidationError(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}"))
|
279 |
-
return False
|
280 |
-
|
281 |
-
# Validate variable dependencies
|
282 |
-
if not self._validate_variable_dependencies(workflow):
|
283 |
-
return False
|
284 |
-
|
285 |
-
return True
|
286 |
-
|
287 |
-
def _validate_workflow_basic(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
288 |
-
"""Validates basic workflow properties"""
|
289 |
-
|
290 |
-
# Check the workflow inputs and outputs
|
291 |
-
if not self.validate_input_outputs(workflow, allow_empty):
|
292 |
-
return False
|
293 |
-
|
294 |
-
# Check for empty workflow
|
295 |
-
if not workflow.steps:
|
296 |
-
if allow_empty:
|
297 |
-
return True
|
298 |
-
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
|
299 |
-
return False
|
300 |
-
|
301 |
-
# Check for step ID consistency
|
302 |
-
for step_id, step in workflow.steps.items():
|
303 |
-
if step_id != step.id:
|
304 |
-
self.errors.append(
|
305 |
-
ValidationError(ValidationErrorType.STEP, f"Step ID mismatch: {step_id} != {step.id}", step_id)
|
306 |
-
)
|
307 |
-
return False
|
308 |
-
return True
|
309 |
-
|
310 |
-
def _validate_step(self, step: ModelStep, allow_empty: bool = False) -> bool:
|
311 |
-
"""Validates a single step"""
|
312 |
-
# Validate required fields
|
313 |
-
|
314 |
-
model_name = step.get_full_model_name()
|
315 |
-
|
316 |
-
if model_name == "/" and not allow_empty:
|
317 |
-
self.errors.append(
|
318 |
-
ValidationError(ValidationErrorType.STEP, "Model name and provider cannot be empty", step.id)
|
319 |
-
)
|
320 |
-
return False
|
321 |
-
|
322 |
-
# Check if the model names are allowed
|
323 |
-
if self.allowed_model_names and model_name not in self.allowed_model_names:
|
324 |
-
self.errors.append(
|
325 |
-
ValidationError(ValidationErrorType.STEP, f"Model name '{model_name}' is not allowed", step.id)
|
326 |
-
)
|
327 |
-
return False
|
328 |
-
|
329 |
-
if not step.id or not step.call_type:
|
330 |
-
self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
|
331 |
-
return False
|
332 |
-
|
333 |
-
# Validate step ID and name
|
334 |
-
if not self._is_valid_identifier(step.id):
|
335 |
-
self.errors.append(
|
336 |
-
ValidationError(
|
337 |
-
ValidationErrorType.NAMING,
|
338 |
-
f"Invalid step ID format: {step.id}. Must be a valid identifier.",
|
339 |
-
step.id,
|
340 |
-
)
|
341 |
-
)
|
342 |
-
return False
|
343 |
-
|
344 |
-
# Validate temperature for LLM call type
|
345 |
-
if step.call_type == CallType.LLM:
|
346 |
-
if step.temperature is None:
|
347 |
-
self.errors.append(
|
348 |
-
ValidationError(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)
|
349 |
-
)
|
350 |
-
return False
|
351 |
-
|
352 |
-
if not self.min_temperature <= step.temperature <= self.max_temperature:
|
353 |
-
self.errors.append(
|
354 |
-
ValidationError(
|
355 |
-
ValidationErrorType.RANGE,
|
356 |
-
f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
|
357 |
-
step.id,
|
358 |
-
)
|
359 |
-
)
|
360 |
-
return False
|
361 |
-
|
362 |
-
# Validate system prompt for LLM call type
|
363 |
-
if step.call_type == CallType.LLM:
|
364 |
-
if not step.system_prompt:
|
365 |
-
self.errors.append(
|
366 |
-
ValidationError(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)
|
367 |
-
)
|
368 |
-
return False
|
369 |
-
|
370 |
-
if len(step.system_prompt) > self.max_system_prompt_length:
|
371 |
-
self.errors.append(
|
372 |
-
ValidationError(
|
373 |
-
ValidationErrorType.LENGTH,
|
374 |
-
f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
|
375 |
-
step.id,
|
376 |
-
)
|
377 |
-
)
|
378 |
-
return False
|
379 |
-
|
380 |
-
# Validate input fields
|
381 |
-
input_names = set()
|
382 |
-
for field in step.input_fields:
|
383 |
-
if not self._validate_input_field(field, allow_empty):
|
384 |
-
return False
|
385 |
-
if field.name in input_names:
|
386 |
-
self.errors.append(
|
387 |
-
ValidationError(
|
388 |
-
ValidationErrorType.STEP, f"Duplicate input field name: {field.name}", step.id, field.name
|
389 |
-
)
|
390 |
-
)
|
391 |
-
return False
|
392 |
-
input_names.add(field.name)
|
393 |
-
|
394 |
-
# Validate output fields
|
395 |
-
output_names = set()
|
396 |
-
for field in step.output_fields:
|
397 |
-
if not self._validate_output_field(field, allow_empty):
|
398 |
-
return False
|
399 |
-
if field.name in output_names:
|
400 |
-
self.errors.append(
|
401 |
-
ValidationError(
|
402 |
-
ValidationErrorType.STEP, f"Duplicate output field name: {field.name}", step.id, field.name
|
403 |
-
)
|
404 |
-
)
|
405 |
-
return False
|
406 |
-
output_names.add(field.name)
|
407 |
-
|
408 |
-
return True
|
409 |
-
|
410 |
-
def _validate_input_field(self, field: InputField, allow_empty: bool = False) -> bool:
|
411 |
-
"""Validates an input field"""
|
412 |
-
# Validate required fields
|
413 |
-
if not field.name or not field.description or not field.variable:
|
414 |
-
self.errors.append(
|
415 |
-
ValidationError(ValidationErrorType.STEP, "Input field missing required fields", field_name=field.name)
|
416 |
-
)
|
417 |
-
return False
|
418 |
-
|
419 |
-
# Validate field name
|
420 |
-
if not self._is_valid_identifier(field.name, allow_empty):
|
421 |
-
self.errors.append(
|
422 |
-
ValidationError(
|
423 |
-
ValidationErrorType.NAMING,
|
424 |
-
f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
|
425 |
-
field_name=field.name,
|
426 |
-
)
|
427 |
-
)
|
428 |
-
return False
|
429 |
-
|
430 |
-
# Validate field name length
|
431 |
-
if len(field.name) > self.max_field_name_length:
|
432 |
-
self.errors.append(
|
433 |
-
ValidationError(
|
434 |
-
ValidationErrorType.LENGTH,
|
435 |
-
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
436 |
-
field_name=field.name,
|
437 |
-
)
|
438 |
-
)
|
439 |
-
return False
|
440 |
-
|
441 |
-
# Validate description length
|
442 |
-
if len(field.description) > self.max_description_length:
|
443 |
-
self.errors.append(
|
444 |
-
ValidationError(
|
445 |
-
ValidationErrorType.LENGTH,
|
446 |
-
f"Description exceeds maximum length of {self.max_description_length} characters",
|
447 |
-
field_name=field.name,
|
448 |
-
)
|
449 |
-
)
|
450 |
-
return False
|
451 |
-
|
452 |
-
# Validate variable reference
|
453 |
-
if not self._is_valid_variable_reference(field.variable):
|
454 |
-
self.errors.append(
|
455 |
-
ValidationError(
|
456 |
-
ValidationErrorType.VARIABLE,
|
457 |
-
f"Invalid variable reference: {field.variable}",
|
458 |
-
field_name=field.name,
|
459 |
-
)
|
460 |
-
)
|
461 |
-
return False
|
462 |
-
|
463 |
-
return True
|
464 |
-
|
465 |
-
def _validate_output_field(self, field: OutputField, allow_empty: bool = False) -> bool:
|
466 |
-
"""Validates an output field"""
|
467 |
-
# Validate required fields
|
468 |
-
if not field.name or not field.description:
|
469 |
-
self.errors.append(
|
470 |
-
ValidationError(
|
471 |
-
ValidationErrorType.STEP, "Output field missing required fields", field_name=field.name
|
472 |
-
)
|
473 |
-
)
|
474 |
-
return False
|
475 |
-
|
476 |
-
# Validate field name
|
477 |
-
if not self._is_valid_identifier(field.name, allow_empty):
|
478 |
-
self.errors.append(
|
479 |
-
ValidationError(
|
480 |
-
ValidationErrorType.NAMING,
|
481 |
-
f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
|
482 |
-
field_name=field.name,
|
483 |
-
)
|
484 |
-
)
|
485 |
-
return False
|
486 |
-
|
487 |
-
# Validate field name length
|
488 |
-
if len(field.name) > self.max_field_name_length:
|
489 |
-
self.errors.append(
|
490 |
-
ValidationError(
|
491 |
-
ValidationErrorType.LENGTH,
|
492 |
-
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
493 |
-
field_name=field.name,
|
494 |
-
)
|
495 |
-
)
|
496 |
-
return False
|
497 |
-
|
498 |
-
# Validate description length
|
499 |
-
if len(field.description) > self.max_description_length:
|
500 |
-
self.errors.append(
|
501 |
-
ValidationError(
|
502 |
-
ValidationErrorType.LENGTH,
|
503 |
-
f"Description exceeds maximum length of {self.max_description_length} characters",
|
504 |
-
field_name=field.name,
|
505 |
-
)
|
506 |
-
)
|
507 |
-
return False
|
508 |
-
|
509 |
-
# Validate type
|
510 |
-
if field.type not in SUPPORTED_TYPES:
|
511 |
-
self.errors.append(
|
512 |
-
ValidationError(
|
513 |
-
ValidationErrorType.TYPE, f"Unsupported output type: {field.type}", field_name=field.name
|
514 |
-
)
|
515 |
-
)
|
516 |
-
return False
|
517 |
-
|
518 |
-
return True
|
519 |
-
|
520 |
-
def _validate_simple_workflow_variables(self, workflow: Workflow) -> bool:
|
521 |
-
"""Validates variables in a simple workflow"""
|
522 |
-
step = next(iter(workflow.steps.values()))
|
523 |
-
|
524 |
-
# Validate input variables
|
525 |
-
for input_var in workflow.inputs:
|
526 |
-
if not self._is_valid_external_input(input_var):
|
527 |
-
self.errors.append(
|
528 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
529 |
-
)
|
530 |
-
return False
|
531 |
-
|
532 |
-
# Validate output variables
|
533 |
-
for output_name, output_var in workflow.outputs.items():
|
534 |
-
if output_var and not self._is_valid_variable_reference(output_var):
|
535 |
-
self.errors.append(
|
536 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
537 |
-
)
|
538 |
-
return False
|
539 |
-
|
540 |
-
return True
|
541 |
-
|
542 |
-
def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
|
543 |
-
"""Validates variable dependencies between steps"""
|
544 |
-
# Build variable dependency graph
|
545 |
-
var_graph: dict[str, set[str]] = {}
|
546 |
-
|
547 |
-
def create_var_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
|
548 |
-
var_graph: dict[str, set[str]] = {}
|
549 |
-
for step_id, step in workflow.steps.items():
|
550 |
-
for field in step.input_fields:
|
551 |
-
if field.variable not in var_graph:
|
552 |
-
var_graph[field.variable] = set()
|
553 |
-
# Add dependency from input variable to step's outputs
|
554 |
-
for output in step.output_fields:
|
555 |
-
var_graph[field.variable].add(f"{step_id}.{output.name}")
|
556 |
-
return var_graph
|
557 |
-
|
558 |
-
# Check for cycles in variable dependencies
|
559 |
-
var_graph = create_var_dep_graph(workflow)
|
560 |
-
if cycle_var := detect_cycles(var_graph):
|
561 |
-
self.errors.append(
|
562 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {cycle_var}")
|
563 |
-
)
|
564 |
-
return False
|
565 |
-
|
566 |
-
# Validate external input existence
|
567 |
-
external_inputs = set(workflow.inputs)
|
568 |
-
for step in workflow.steps.values():
|
569 |
-
for field in step.input_fields:
|
570 |
-
step_id, field_name = _parse_variable_reference(field.variable)
|
571 |
-
if not step_id and field_name not in external_inputs:
|
572 |
-
self.errors.append(
|
573 |
-
ValidationError(
|
574 |
-
ValidationErrorType.VARIABLE,
|
575 |
-
f"External input '{field_name}' not found in workflow inputs",
|
576 |
-
field_name=field_name,
|
577 |
-
)
|
578 |
-
)
|
579 |
-
return False
|
580 |
-
|
581 |
-
return True
|
582 |
-
|
583 |
-
def _is_valid_variable_reference(self, var: str | None, allow_empty: bool = True) -> bool:
|
584 |
-
"""Validates if a variable reference is properly formatted"""
|
585 |
-
if not self.workflow:
|
586 |
-
return False
|
587 |
-
if var is None:
|
588 |
-
return allow_empty
|
589 |
-
parts = var.split(".")
|
590 |
-
if len(parts) == 1:
|
591 |
-
return True # External input
|
592 |
-
if len(parts) != 2:
|
593 |
-
return False
|
594 |
-
step_id, field_name = parts
|
595 |
-
return step_id in self.workflow.steps and any(
|
596 |
-
field.name == field_name for field in self.workflow.steps[step_id].output_fields
|
597 |
-
)
|
598 |
-
|
599 |
-
def _is_valid_external_input(self, var: str) -> bool:
|
600 |
-
"""Validates if a variable is a valid external input"""
|
601 |
-
if not var:
|
602 |
-
return False
|
603 |
-
if not self._is_valid_identifier(var):
|
604 |
-
return False
|
605 |
-
if keyword.iskeyword(var):
|
606 |
-
return False
|
607 |
-
if "." in var: # External inputs should not contain dots
|
608 |
-
return False
|
609 |
-
return True
|
610 |
-
|
611 |
-
def _is_valid_identifier(self, name: str, allow_empty: bool = False) -> bool:
|
612 |
-
"""Validates if a string is a valid Python identifier"""
|
613 |
-
if name and name.strip():
|
614 |
-
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
615 |
-
return allow_empty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|